def face_detection_loop(vid: cv2.VideoCapture, face_detector: RetinaFacePredictor,
                        landmark_detector: FANPredictor, window_title: str) \
        -> Tuple[Union[np.ndarray, None], Union[np.ndarray, None]]:
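    """Run face and landmark detection on the input video until the user
    captures a face ('C') or quits ('Q').

    Returns the captured frame and the landmarks of the selected face, or
    (None, None) if the user quits.
    """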
    print(
        'Face and landmark detection started, you can use the following commands:\n'
        + '  |_C: Capture the selected face for pose augmentation.\n'
        '  |_Q: Quit the demo.')
    while True:
        _, frame = vid.read()
        if frame is None:
            print('Failed to grab a new frame, exiting.')
            return None, None

        # Face and landmark detection
        faces = face_detector(frame, rgb=False)
        landmarks, scores = landmark_detector(frame, faces, rgb=False)

        # Try to select a face (the largest face whose landmark scores are all
        # at least 0.2)
        face_sizes = [(bbox[2] - bbox[0]) *
                      (bbox[3] - bbox[1]) if scs.min() >= 0.2 else -1
                      for bbox, scs in zip(faces, scores)]
        selected_face = np.argmax(
            face_sizes) if len(face_sizes) > 0 and max(face_sizes) > 0 else -1

        # Visualisation
        for idx, (lms, scs) in enumerate(zip(landmarks, scores)):
            if idx != selected_face:
                plot_landmarks(frame,
                               lms,
                               scs,
                               line_colour=(192, 192, 192),
                               pts_colour=(128, 128, 128))
        if selected_face < 0:
            frame_vis = frame
        else:
            frame_vis = frame.copy()
            plot_landmarks(frame_vis, landmarks[selected_face],
                           scores[selected_face])

        # Show the frame and process commands
        cv2.imshow(window_title, frame_vis)
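        # Keep only the lower 16 bits of the key code; some HighGUI backends
        # set extra flag bits in the value returned by waitKey()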
        key = cv2.waitKey(1) % 2**16
        if key == ord('c') or key == ord('C'):
            if selected_face < 0:
                print('\'C\' pressed, but no face is currently selected.')
            else:
                print(
                    '\'C\' pressed, applying pose augmentation to the selected face.'
                )
                return frame, landmarks[selected_face]
        elif key == ord('q') or key == ord('Q'):
            print('\'Q\' pressed, we are done here.')
            return None, None


def main() -> None:
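    """Run the 3DDFA demo: detect faces, fit FAN landmarks, reconstruct a 3D
    face with 3DDFA, and render the projected vertices on every frame of a
    video or webcam stream."""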
    # Parse command-line arguments
    parser = ArgumentParser()
    parser.add_argument('--input',
                        '-i',
                        help='Input video path or webcam index (default=0)',
                        default=0)
    parser.add_argument('--output',
                        '-o',
                        help='Output file path',
                        default=None)
    parser.add_argument('--benchmark',
                        '-b',
                        help='Enable benchmark mode for CUDNN',
                        action='store_true',
                        default=False)
    parser.add_argument('--no-display',
                        '-n',
                        help='No display if processing a video file',
                        action='store_true',
                        default=False)

    parser.add_argument(
        '--detection-threshold',
        '-dt',
        type=float,
        default=0.8,
        help='Confidence threshold for face detection (default=0.8)')
    parser.add_argument(
        '--detection-method',
        '-dm',
        default='retinaface',
        help=
        'Face detection method, can be either RetinaFace or S3FD (default=RetinaFace)'
    )
    parser.add_argument(
        '--detection-weights',
        '-dw',
        default=None,
        help='Weights to be loaded for face detection, ' +
        'can be either resnet50 or mobilenet0.25 when using RetinaFace')
    parser.add_argument(
        '--detection-device',
        '-dd',
        default='cuda:0',
        help='Device to be used for face detection (default=cuda:0)')

    parser.add_argument(
        '--alignment-threshold',
        '-at',
        type=float,
        default=0.2,
        help=
        'Score threshold used when visualising detected landmarks (default=0.2)'
    )
    parser.add_argument('--alignment-method',
                        '-am',
                        default='fan',
                        help='Face alignment method, must be set to FAN')
    parser.add_argument(
        '--alignment-weights',
        '-aw',
        default=None,
        help=
        'Weights to be loaded for face alignment, can be either 2DFAN2 or 2DFAN4'
    )
    parser.add_argument(
        '--alignment-device',
        '-ad',
        default='cuda:0',
        help='Device to be used for face alignment (default=cuda:0)')

    parser.add_argument(
        '--tddfa-weights',
        '-tw',
        default=None,
        help='Weights to be loaded by 3DDFA, must be set to mobilenet1')
    parser.add_argument('--tddfa-device',
                        '-td',
                        default='cuda:0',
                        help='Device to be used by 3DDFA (default=cuda:0)')
    args = parser.parse_args()

    # Set benchmark mode flag for CUDNN
    torch.backends.cudnn.benchmark = args.benchmark

    vid = None
    out_vid = None
    has_window = False
    try:
        # Create the face detector
        args.detection_method = args.detection_method.lower()
        if args.detection_method == 'retinaface':
            face_detector = RetinaFacePredictor(
                threshold=args.detection_threshold,
                device=args.detection_device,
                model=(RetinaFacePredictor.get_model(args.detection_weights)
                       if args.detection_weights else None))
            print('Face detector created using RetinaFace.')
        elif args.detection_method == 's3fd':
            face_detector = S3FDPredictor(
                threshold=args.detection_threshold,
                device=args.detection_device,
                model=(S3FDPredictor.get_model(args.detection_weights)
                       if args.detection_weights else None))
            print('Face detector created using S3FD.')
        else:
            raise ValueError(
                'detector-method must be set to either RetinaFace or S3FD')

        # Create the landmark detector
        args.alignment_method = args.alignment_method.lower()
        if args.alignment_method == 'fan':
            landmark_detector = FANPredictor(
                device=args.alignment_device,
                model=(FANPredictor.get_model(args.alignment_weights)
                       if args.alignment_weights else None))
            print('Landmark detector created using FAN.')
        else:
            raise ValueError('alignment-method must be set to FAN')

        # Instantiate 3DDFA
        tddfa = TDDFAPredictor(
            device=args.tddfa_device,
            model=(TDDFAPredictor.get_model(args.tddfa_weights)
                   if args.tddfa_weights else None))
        print('3DDFA initialised.')

        # Open the input video
        using_webcam = not os.path.exists(args.input)
        vid = cv2.VideoCapture(int(args.input) if using_webcam else args.input)
        assert vid.isOpened()
        if using_webcam:
            print(f'Webcam #{int(args.input)} opened.')
        else:
            print(f'Input video "{args.input}" opened.')

        # Open the output video (if a path is given)
        if args.output is not None:
            out_vid = cv2.VideoWriter(
                args.output,
                apiPreference=cv2.CAP_FFMPEG,
                fps=vid.get(cv2.CAP_PROP_FPS),
                frameSize=(int(vid.get(cv2.CAP_PROP_FRAME_WIDTH)),
                           int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))),
                fourcc=cv2.VideoWriter_fourcc('m', 'p', '4', 'v'))
            assert out_vid.isOpened()

        # Process the frames
        frame_number = 0
        window_title = os.path.splitext(os.path.basename(__file__))[0]
        print('Processing started, press \'Q\' to quit.')
        while True:
            # Get a new frame
            _, frame = vid.read()
            if frame is None:
                break
            else:
                # Detect faces
                start_time = time.time()
                faces = face_detector(frame, rgb=False)
                current_time = time.time()
                elapsed_time = current_time - start_time

                # Face alignment
                start_time = current_time
                landmarks, scores = landmark_detector(frame, faces, rgb=False)
                current_time = time.time()
                elapsed_time2 = current_time - start_time

                # Dense 3D face reconstruction using 3DDFA
                start_time = current_time
                tddfa_results = TDDFAPredictor.decode(
                    tddfa(frame, landmarks, rgb=False, two_steps=True))
                current_time = time.time()
                elapsed_time3 = current_time - start_time

                # Textual output
                print(
                    f'Frame #{frame_number} processed in {elapsed_time * 1000.0:.04f} + '
                    +
                    f'{elapsed_time2 * 1000.0:.04f} + {elapsed_time3 * 1000.0:.04f} ms: '
                    + f'{len(faces)} faces analysed.')

                # Rendering
                for face, tddfa_result in zip(faces, tddfa_results):
                    bbox = face[:4].astype(int)
                    cv2.rectangle(frame, (bbox[0], bbox[1]),
                                  (bbox[2], bbox[3]),
                                  color=(0, 0, 255),
                                  thickness=2)
                    lm = tddfa.project_vertex(tddfa_result, False)
                    plot_landmarks(frame, lm[:, :2])
                    if len(face) > 5:
                        plot_landmarks(frame,
                                       face[5:].reshape((-1, 2)),
                                       pts_radius=3)

                # Write the frame to output video (if recording)
                if out_vid is not None:
                    out_vid.write(frame)

                # Display the frame
                if using_webcam or not args.no_display:
                    has_window = True
                    cv2.imshow(window_title, frame)
                    key = cv2.waitKey(1) % 2**16
                    if key == ord('q') or key == ord('Q'):
                        print('\'Q\' pressed, we are done here.')
                        break
                frame_number += 1
    finally:
        if has_window:
            cv2.destroyAllWindows()
        if out_vid is not None:
            out_vid.release()
        if vid is not None:
            vid.release()
        print('All done.')


def main() -> None:
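    """Run the emotion recognition demo: detect faces, fit FAN landmarks, and
    recognise the emotion (discrete label plus valence and arousal) of every
    face with EmoNet, on every frame of a video or webcam stream."""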
    # Parse command-line arguments
    parser = ArgumentParser()
    parser.add_argument('--input',
                        '-i',
                        help='Input video path or webcam index (default=0)',
                        default=0)
    parser.add_argument('--output',
                        '-o',
                        help='Output file path',
                        default=None)
    parser.add_argument('--fourcc',
                        '-f',
                        help='FourCC of the output video (default=mp4v)',
                        type=str,
                        default='mp4v')
    parser.add_argument('--benchmark',
                        '-b',
                        help='Enable benchmark mode for CUDNN',
                        action='store_true',
                        default=False)
    parser.add_argument('--no-display',
                        '-n',
                        help='No display if processing a video file',
                        action='store_true',
                        default=False)

    parser.add_argument(
        '--detection-threshold',
        '-dt',
        type=float,
        default=0.8,
        help='Confidence threshold for face detection (default=0.8)')
    parser.add_argument(
        '--detection-method',
        '-dm',
        default='retinaface',
        help=
        'Face detection method, can be either RetinaFace or S3FD (default=RetinaFace)'
    )
    parser.add_argument(
        '--detection-weights',
        '-dw',
        default=None,
        help='Weights to be loaded for face detection, ' +
        'can be either resnet50 or mobilenet0.25 when using RetinaFace')
    parser.add_argument(
        '--detection-alternative-pth',
        '-dp',
        default=None,
        help='Alternative pth file to be loaded for face detection')
    parser.add_argument(
        '--detection-device',
        '-dd',
        default='cuda:0',
        help='Device to be used for face detection (default=cuda:0)')

    parser.add_argument(
        '--alignment-threshold',
        '-at',
        type=float,
        default=0.2,
        help=
        'Score threshold used when visualising detected landmarks (default=0.2)'
    )
    parser.add_argument('--alignment-method',
                        '-am',
                        default='fan',
                        help='Face alignment method, must be set to FAN')
    parser.add_argument(
        '--alignment-weights',
        '-aw',
        default=None,
        help=
        'Weights to be loaded for face alignment, can be either 2DFAN2, 2DFAN4, '
        + 'or 2DFAN2_ALT')
    parser.add_argument(
        '--alignment-alternative-pth',
        '-ap',
        default=None,
        help='Alternative pth file to be loaded for face alignment')
    parser.add_argument(
        '--alignment-device',
        '-ad',
        default='cuda:0',
        help='Device to be used for face alignment (default=cuda:0)')

    parser.add_argument(
        '--emotion-method',
        '-em',
        default='emonet',
        help='Emotion recognition method, must be set to EmoNet')
    parser.add_argument(
        '--emotion-weights',
        '-ew',
        default=None,
        help='Weights to be loaded for emotion recognition, can be either ' +
        'EmoNet248, EmoNet245, EmoNet248_alt, or EmoNet245_alt')
    parser.add_argument(
        '--emotion-alternative-pth',
        '-ep',
        default=None,
        help='Alternative pth file to be loaded for emotion recognition')
    parser.add_argument(
        '--emotion-device',
        '-ed',
        default='cuda:0',
        help='Device to be used for emotion recognition (default=cuda:0)')
    args = parser.parse_args()

    # Set benchmark mode flag for CUDNN
    torch.backends.cudnn.benchmark = args.benchmark

    vid = None
    out_vid = None
    has_window = False
    try:
        # Create the face detector
        args.detection_method = args.detection_method.lower()
        if args.detection_method == 'retinaface':
            face_detector_class = (RetinaFacePredictor, 'RetinaFace')
        elif args.detection_method == 's3fd':
            face_detector_class = (S3FDPredictor, 'S3FD')
        else:
            raise ValueError(
                'detector-method must be set to either RetinaFace or S3FD')
        if args.detection_weights is None:
            fd_model = face_detector_class[0].get_model()
        else:
            fd_model = face_detector_class[0].get_model(args.detection_weights)
        if args.detection_alternative_pth is not None:
            fd_model.weights = args.detection_alternative_pth
        face_detector = face_detector_class[0](
            threshold=args.detection_threshold,
            device=args.detection_device,
            model=fd_model)
        print(
            f"Face detector created using {face_detector_class[1]} ({fd_model.weights})."
        )

        # Create the landmark detector
        args.alignment_method = args.alignment_method.lower()
        if args.alignment_method == 'fan':
            if args.alignment_weights is None:
                fa_model = FANPredictor.get_model()
            else:
                fa_model = FANPredictor.get_model(args.alignment_weights)
            if args.alignment_alternative_pth is not None:
                fa_model.weights = args.alignment_alternative_pth
            landmark_detector = FANPredictor(device=args.alignment_device,
                                             model=fa_model)
            print(f"Landmark detector created using FAN ({fa_model.weights}).")
        else:
            raise ValueError('alignment-method must be set to FAN')

        # Create the emotion recogniser
        args.emotion_method = args.emotion_method.lower()
        if args.emotion_method == 'emonet':
            if args.emotion_weights is None:
                er_model = EmoNetPredictor.get_model()
            else:
                er_model = EmoNetPredictor.get_model(args.emotion_weights)
            if args.emotion_alternative_pth is not None:
                er_model.weights = args.emotion_alternative_pth
            emotion_recogniser = EmoNetPredictor(device=args.emotion_device,
                                                 model=er_model)
            print(
                f"Emotion recogniser created using EmoNet ({er_model.weights})."
            )
        else:
            raise ValueError('emotion-method must be set to EmoNet')

        # Open the input video
        using_webcam = not os.path.exists(args.input)
        vid = cv2.VideoCapture(int(args.input) if using_webcam else args.input)
        assert vid.isOpened()
        if using_webcam:
            print(f'Webcam #{int(args.input)} opened.')
        else:
            print(f'Input video "{args.input}" opened.')

        # Open the output video (if a path is given)
        if args.output is not None:
            out_vid = cv2.VideoWriter(
                args.output,
                fps=vid.get(cv2.CAP_PROP_FPS),
                frameSize=(int(vid.get(cv2.CAP_PROP_FRAME_WIDTH)),
                           int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))),
                fourcc=cv2.VideoWriter_fourcc(*args.fourcc))
            assert out_vid.isOpened()

        # Process the frames
        frame_number = 0
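        # One BGR colour per emotion label (the 8-class and 5-class EmoNet
        # models use different label sets)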
        if len(emotion_recogniser.config.emotion_labels) == 8:
            emotion_colours = ((192, 192, 192), (0, 255, 0), (255, 0, 0),
                               (0, 255, 255), (0, 128, 255), (255, 0, 128),
                               (0, 0, 255), (128, 255, 0))
        else:
            emotion_colours = ((192, 192, 192), (0, 255, 0), (255, 0, 0),
                               (0, 255, 255), (0, 0, 255))
        window_title = os.path.splitext(os.path.basename(__file__))[0]
        print('Processing started, press \'Q\' to quit.')
        while True:
            # Get a new frame
            _, frame = vid.read()
            if frame is None:
                break
            else:
                # Detect faces
                start_time = time.time()
                faces = face_detector(frame, rgb=False)
                current_time = time.time()
                elapsed_time = current_time - start_time

                # Face alignment
                start_time = current_time
                landmarks, scores, fan_features = landmark_detector(
                    frame, faces, rgb=False, return_features=True)
                current_time = time.time()
                elapsed_time2 = current_time - start_time

                # Emotion recognition
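                # The recogniser runs directly on the intermediate features
                # returned by FAN (return_features=True above), so the faces
                # are not re-processed from the image here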
                start_time = current_time
                emotions = emotion_recogniser(fan_features)
                current_time = time.time()
                elapsed_time3 = current_time - start_time

                # Textual output
                print(
                    f'Frame #{frame_number} processed in {elapsed_time * 1000.0:.04f} + '
                    +
                    f'{elapsed_time2 * 1000.0:.04f} + {elapsed_time3 * 1000.0:.04f} ms: '
                    + f'{len(faces)} faces analysed.')

                # Rendering
                for idx, (face, lm,
                          sc) in enumerate(zip(faces, landmarks, scores)):
                    bbox = face[:4].astype(int)
                    cv2.rectangle(frame, (bbox[0], bbox[1]),
                                  (bbox[2], bbox[3]),
                                  color=(0, 0, 255),
                                  thickness=2)
                    plot_landmarks(frame,
                                   lm,
                                   sc,
                                   threshold=args.alignment_threshold)
                    if len(face) > 5:
                        plot_landmarks(frame,
                                       face[5:].reshape((-1, 2)),
                                       pts_radius=3)
                    emo = emotion_recogniser.config.emotion_labels[
                        emotions['emotion'][idx]].title()
                    val, ar = emotions['valence'][idx], emotions['arousal'][
                        idx]
                    text_content = f'{emo} ({val: .01f}, {ar: .01f})'
                    cv2.putText(frame,
                                text_content, (bbox[0], bbox[1] - 10),
                                cv2.FONT_HERSHEY_DUPLEX,
                                0.5,
                                emotion_colours[emotions['emotion'][idx]],
                                lineType=cv2.LINE_AA)

                # Write the frame to output video (if recording)
                if out_vid is not None:
                    out_vid.write(frame)

                # Display the frame
                if using_webcam or not args.no_display:
                    has_window = True
                    cv2.imshow(window_title, frame)
                    key = cv2.waitKey(1) % 2**16
                    if key == ord('q') or key == ord('Q'):
                        print('\'Q\' pressed, we are done here.')
                        break
                frame_number += 1
    finally:
        if has_window:
            cv2.destroyAllWindows()
        if out_vid is not None:
            out_vid.release()
        if vid is not None:
            vid.release()
        print('All done.')


def face_pose_augmentation_loop(tddfa: TDDFAPredictor,
                                augmentor: FacePoseAugmentor,
                                frame: np.ndarray, landmarks: np.ndarray,
                                landmark_style_index: int,
                                window_title: str) -> int:
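    """Fit a 3D face model to the captured frame and let the user browse
    pose-augmented renderings interactively.

    Returns the currently selected landmark style index (so that it persists
    across captures), or -1 if the user quits the demo.
    """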
    # Apply 3DDFA
    start_time = time.time()
    tddfa_result = TDDFAPredictor.decode(tddfa(frame, landmarks, rgb=False))[0]
    pitch, yaw, roll = np.array(
        [tddfa_result['face_pose'][k]
         for k in ('pitch', 'yaw', 'roll')]) / np.pi * 180.0
    print(
        f'3D face model fitted in {(time.time() - start_time) * 1000.0:.3f} ms.'
    )
    print(
        'The estimated head pose (pitch, yaw, and roll, in degrees) is '
        f'({pitch:.3f}, {yaw:.3f}, {roll:.3f}).')

    # Determine delta poses
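    # Pitch offsets span -20 to +20 degrees in 10-degree steps; yaw offsets
    # sweep from the current pose towards the nearer profile view (+/-90
    # degrees), also in 10-degree steps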
    delta_poses = []
    delta_pitchs = np.arange(-20, 21, 10)
    delta_yaws = np.arange(0, -90 - yaw, -10) if yaw < 0 else np.arange(
        0, 90 - yaw, 10)
    for dp in delta_pitchs:
        for dy in delta_yaws:
            delta_poses.append((dp, dy, 0))
    delta_poses = np.array(delta_poses) / 180.0 * np.pi

    # Pose augmentation
    start_time = time.time()
    augmentation_results = augmentor(frame, tddfa_result, delta_poses,
                                     landmarks)
    print(
        f'Pose augmentation finished in {(time.time() - start_time):.3f} seconds.'
    )

    # Display the result
    dp_idx, dy_idx = len(delta_pitchs) // 2, 0
    landmark_styles = ['3d_style', '2d_style', 'projected_3d', 'refined_2d']
    if landmark_style_index > 0:
        print(
            f'Displaying result with \'{landmark_styles[landmark_style_index - 1]}\' landmarks, '
            + 'you can use the following commands:')
    else:
        print(
            'Displaying result with no landmarks, you can use the following commands:'
        )
    print('  |_A: Turn left (decrease yaw).\n' +
          '  |_D: Turn right (increase yaw).\n' +
          '  |_W: Tilt up (decrease pitch).\n' +
          '  |_S: Tilt down (increase pitch).\n' +
          '  |_0: Do not display landmarks.\n' +
          f'  |_1: Display \'{landmark_styles[0]}\' landmarks.\n' +
          f'  |_2: Display \'{landmark_styles[1]}\' landmarks.\n' +
          f'  |_3: Display \'{landmark_styles[2]}\' landmarks.\n' +
          f'  |_4: Display \'{landmark_styles[3]}\' landmarks.\n' +
          '  |_C: Go back to face and landmark detection.\n' +
          '  |_Q: Quit the demo.')
    while True:
        result = augmentation_results[dp_idx * len(delta_yaws) + dy_idx]
        if landmark_style_index > 0:
            frame_vis = result['warped_image'].copy()
            plot_landmarks(
                frame_vis,
                result['warped_landmarks'][landmark_styles[landmark_style_index
                                                           - 1]][:, :2])
        else:
            frame_vis = result['warped_image']
        cv2.imshow(window_title, frame_vis)
        key = cv2.waitKey(0) % 2**16
        if key == ord('a') or key == ord('A'):
            dy_idx = min(dy_idx + 1,
                         len(delta_yaws) -
                         1) if yaw < 0 else max(0, dy_idx - 1)
            print(
                f'\'A\' pressed: turning left by setting head pose to ({pitch:.3f}, {yaw:.3f}, {roll:.3f}) + '
                +
                f'({delta_pitchs[dp_idx]:.1f}, {delta_yaws[dy_idx]:.1f}, 0.0)')
        elif key == ord('d') or key == ord('D'):
            dy_idx = max(0, dy_idx -
                         1) if yaw < 0 else min(dy_idx + 1,
                                                len(delta_yaws) - 1)
            print(
                f'\'D\' pressed: turning right by setting head pose to ({pitch:.3f}, {yaw:.3f}, {roll:.3f}) + '
                +
                f'({delta_pitchs[dp_idx]:.1f}, {delta_yaws[dy_idx]:.1f}, 0.0)')
        elif key == ord('w') or key == ord('W'):
            dp_idx = max(0, dp_idx - 1)
            print(
                f'\'W\' pressed: tilting up by setting head pose to ({pitch:.3f}, {yaw:.3f}, {roll:.3f}) + '
                +
                f'({delta_pitchs[dp_idx]:.1f}, {delta_yaws[dy_idx]:.1f}, 0.0)')
        elif key == ord('s') or key == ord('S'):
            dp_idx = min(dp_idx + 1, len(delta_pitchs) - 1)
            print(
                f'\'S\' pressed: tilting down by setting head pose to ({pitch:.3f}, {yaw:.3f}, {roll:.3f}) + '
                +
                f'({delta_pitchs[dp_idx]:.1f}, {delta_yaws[dy_idx]:.1f}, 0.0)')
        elif ord('0') <= key <= ord('4'):
            landmark_style_index = key - ord('0')
            if landmark_style_index > 0:
                print(
                    f'\'{chr(key)}\' pressed, setting to display ' +
                    f'\'{landmark_styles[landmark_style_index - 1]}\' landmarks.'
                )
            else:
                print(
                    f'\'{chr(key)}\' pressed, setting to not display landmarks.'
                )
        elif key == ord('c') or key == ord('C'):
            print('\'C\' pressed, going back to face and landmark detection.')
            return landmark_style_index
        elif key == ord('q') or key == ord('Q'):
            print('\'Q\' pressed, we are done here.')
            return -1
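

# NOTE: the entry point that chains face_detection_loop() and
# face_pose_augmentation_loop() together is not included in this listing. The
# function below is an illustrative sketch only (not the authors' code),
# assuming the detectors, the TDDFAPredictor, and the FacePoseAugmentor have
# already been constructed as in the main() functions above.
def run_pose_augmentation_demo(vid: cv2.VideoCapture,
                               face_detector: RetinaFacePredictor,
                               landmark_detector: FANPredictor,
                               tddfa: TDDFAPredictor,
                               augmentor: FacePoseAugmentor,
                               window_title: str) -> None:
    landmark_style_index = 0  # start without a landmark overlay (arbitrary)
    while True:
        # Let the user capture a face ('C') or quit ('Q') ...
        frame, landmarks = face_detection_loop(vid, face_detector,
                                               landmark_detector, window_title)
        if frame is None:
            break
        # ... then browse the pose-augmented renderings of the captured face
        landmark_style_index = face_pose_augmentation_loop(
            tddfa, augmentor, frame, landmarks, landmark_style_index,
            window_title)
        if landmark_style_index < 0:
            break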