Example #1 (score: 0)
File: swap.py — Project: KSRawal/fsgan
    def __init__(
        self,
        # Note: d('...') presumably looks up the project-wide default for each
        # argument name — the helper is defined outside this chunk; verify.
        resolution=d('resolution'),
        crop_scale=d('crop_scale'),
        gpus=d('gpus'),
        cpu_only=d('cpu_only'),
        display=d('display'),
        verbose=d('verbose'),
        encoder_codec=d('encoder_codec'),
        # Detection arguments:
        detection_model=d('detection_model'),
        det_batch_size=d('det_batch_size'),
        det_postfix=d('det_postfix'),
        # Sequence arguments:
        iou_thresh=d('iou_thresh'),
        min_length=d('min_length'),
        min_size=d('min_size'),
        center_kernel=d('center_kernel'),
        size_kernel=d('size_kernel'),
        smooth_det=d('smooth_det'),
        seq_postfix=d('seq_postfix'),
        write_empty=d('write_empty'),
        # Pose arguments:
        pose_model=d('pose_model'),
        pose_batch_size=d('pose_batch_size'),
        pose_postfix=d('pose_postfix'),
        cache_pose=d('cache_pose'),
        cache_frontal=d('cache_frontal'),
        smooth_poses=d('smooth_poses'),
        # Landmarks arguments:
        lms_model=d('lms_model'),
        lms_batch_size=d('lms_batch_size'),
        landmarks_postfix=d('landmarks_postfix'),
        cache_landmarks=d('cache_landmarks'),
        smooth_landmarks=d('smooth_landmarks'),
        # Segmentation arguments:
        seg_model=d('seg_model'),
        smooth_segmentation=d('smooth_segmentation'),
        segmentation_postfix=d('segmentation_postfix'),
        cache_segmentation=d('cache_segmentation'),
        seg_batch_size=d('seg_batch_size'),
        seg_remove_mouth=d('seg_remove_mouth'),
        # Finetune arguments:
        finetune=d('finetune'),
        finetune_iterations=d('finetune_iterations'),
        finetune_lr=d('finetune_lr'),
        finetune_batch_size=d('finetune_batch_size'),
        finetune_workers=d('finetune_workers'),
        finetune_save=d('finetune_save'),
        # Swapping arguments:
        batch_size=d('batch_size'),
        reenactment_model=d('reenactment_model'),
        completion_model=d('completion_model'),
        blending_model=d('blending_model'),
        criterion_id=d('criterion_id'),
        min_radius=d('min_radius'),
        output_crop=d('output_crop'),
        renderer_process=d('renderer_process')):
        """Initialize the face swapping pipeline.

        Forwards the shared preprocessing arguments (detection, sequence,
        pose, landmarks, segmentation) to the base class, then loads the
        three swapping networks (reenactment, completion, blending), the
        identity criterion, and starts the video renderer process.

        NOTE(review): the ``cache_pose``, ``cache_landmarks`` and
        ``cache_segmentation`` parameters are accepted here but NOT
        forwarded — the super() call below hard-codes all three to True,
        presumably because swapping requires the cached preprocessing
        outputs; confirm against the base class.
        """
        super(FaceSwapping,
              self).__init__(resolution,
                             crop_scale,
                             gpus,
                             cpu_only,
                             display,
                             verbose,
                             encoder_codec,
                             detection_model=detection_model,
                             det_batch_size=det_batch_size,
                             det_postfix=det_postfix,
                             iou_thresh=iou_thresh,
                             min_length=min_length,
                             min_size=min_size,
                             center_kernel=center_kernel,
                             size_kernel=size_kernel,
                             smooth_det=smooth_det,
                             seq_postfix=seq_postfix,
                             write_empty=write_empty,
                             pose_model=pose_model,
                             pose_batch_size=pose_batch_size,
                             pose_postfix=pose_postfix,
                             cache_pose=True,  # overrides the cache_pose argument (see docstring)
                             cache_frontal=cache_frontal,
                             smooth_poses=smooth_poses,
                             lms_model=lms_model,
                             lms_batch_size=lms_batch_size,
                             landmarks_postfix=landmarks_postfix,
                             cache_landmarks=True,  # overrides the cache_landmarks argument
                             smooth_landmarks=smooth_landmarks,
                             seg_model=seg_model,
                             seg_batch_size=seg_batch_size,
                             segmentation_postfix=segmentation_postfix,
                             cache_segmentation=True,  # overrides the cache_segmentation argument
                             smooth_segmentation=smooth_segmentation,
                             seg_remove_mouth=seg_remove_mouth)
        # Swapping / fine-tuning configuration (stored verbatim).
        self.batch_size = batch_size
        self.min_radius = min_radius
        self.output_crop = output_crop
        self.finetune_enabled = finetune
        self.finetune_iterations = finetune_iterations
        self.finetune_lr = finetune_lr
        self.finetune_batch_size = finetune_batch_size
        self.finetune_workers = finetune_workers
        self.finetune_save = finetune_save

        # Load reenactment model
        # The state dict is kept so the reenactment weights can be restored
        # later (e.g. after per-subject fine-tuning) — TODO confirm usage.
        self.Gr, checkpoint = load_model(reenactment_model,
                                         'face reenactment',
                                         self.device,
                                         return_checkpoint=True)
        self.Gr.arch = checkpoint['arch']
        self.reenactment_state_dict = checkpoint['state_dict']

        # Load all other models
        self.Gc = load_model(completion_model, 'face completion', self.device)
        self.Gb = load_model(blending_model, 'face blending', self.device)

        # Initialize landmarks decoders
        # insert(0, ...) builds the list in descending resolution order:
        # [decoder(256), decoder(128)].
        self.landmarks_decoders = []
        for res in (128, 256):
            self.landmarks_decoders.insert(
                0,
                LandmarksHeatMapDecoder(res).to(self.device))

        # Initialize losses
        # criterion_id is instantiated from its string/object spec via the
        # project's obj_factory helper.
        self.criterion_pixelwise = nn.L1Loss().to(self.device)
        self.criterion_id = obj_factory(criterion_id).to(self.device)

        # Support multiple GPUs
        # Wrapping must happen after the models are loaded to self.device.
        if self.gpus and len(self.gpus) > 1:
            self.Gr = nn.DataParallel(self.Gr, self.gpus)
            self.Gc = nn.DataParallel(self.Gc, self.gpus)
            self.Gb = nn.DataParallel(self.Gb, self.gpus)
            self.criterion_id.vgg = nn.DataParallel(self.criterion_id.vgg,
                                                    self.gpus)

        # Initialize soft erosion
        # Fixed kernel/threshold; used to soften segmentation mask edges —
        # TODO confirm against SoftErosion's implementation.
        self.smooth_mask = SoftErosion(kernel_size=21,
                                       threshold=0.6).to(self.device)

        # Initialize video writer
        # Started last, once all models are ready; runs as a separate
        # process when renderer_process is enabled.
        self.video_renderer = FaceSwappingRenderer(
            self.display, self.verbose, self.output_crop, self.resolution,
            self.crop_scale, encoder_codec, renderer_process)
        self.video_renderer.start()
Example #2 (score: 0)
    def __init__(
        self,
        # Note: d('...') presumably looks up the project-wide default for each
        # argument name — the helper is defined outside this chunk; verify.
        resolution=d('resolution'),
        crop_scale=d('crop_scale'),
        gpus=d('gpus'),
        cpu_only=d('cpu_only'),
        display=d('display'),
        verbose=d('verbose'),
        # Detection arguments:
        detection_model=d('detection_model'),
        det_batch_size=d('det_batch_size'),
        det_postfix=d('det_postfix'),
        # Sequence arguments:
        iou_thresh=d('iou_thresh'),
        min_length=d('min_length'),
        min_size=d('min_size'),
        center_kernel=d('center_kernel'),
        size_kernel=d('size_kernel'),
        smooth_det=d('smooth_det'),
        seq_postfix=d('seq_postfix'),
        write_empty=d('write_empty'),
        # Pose arguments:
        pose_model=d('pose_model'),
        pose_batch_size=d('pose_batch_size'),
        pose_postfix=d('pose_postfix'),
        cache_pose=d('cache_pose'),
        cache_frontal=d('cache_frontal'),
        smooth_poses=d('smooth_poses'),
        # Landmarks arguments:
        lms_model=d('lms_model'),
        lms_batch_size=d('lms_batch_size'),
        landmarks_postfix=d('landmarks_postfix'),
        cache_landmarks=d('cache_landmarks'),
        smooth_landmarks=d('smooth_landmarks'),
        # Segmentation arguments:
        seg_model=d('seg_model'),
        seg_batch_size=d('seg_batch_size'),
        segmentation_postfix=d('segmentation_postfix'),
        cache_segmentation=d('cache_segmentation'),
        smooth_segmentation=d('smooth_segmentation'),
        seg_remove_mouth=d('seg_remove_mouth')):
        """Initialize the shared face preprocessing pipeline.

        Stores the detection / sequence / pose / landmarks / segmentation
        configuration, selects the compute device, and lazily loads only
        the models whose ``cache_*`` flag is enabled (pose, landmarks,
        segmentation). Runs inference-only: gradients are globally
        disabled below.
        """
        # General
        self.resolution = resolution
        self.crop_scale = crop_scale
        self.display = display
        self.verbose = verbose

        # Detection
        # The face detector manages its own device placement from `gpus`
        # (constructed before set_device below) — TODO confirm.
        self.face_detector = FaceDetector(det_postfix, detection_model, gpus,
                                          det_batch_size, display)
        self.det_postfix = det_postfix

        # Sequences
        self.iou_thresh = iou_thresh
        self.min_length = min_length
        self.min_size = min_size
        self.center_kernel = center_kernel
        self.size_kernel = size_kernel
        self.smooth_det = smooth_det
        self.seq_postfix = seq_postfix
        self.write_empty = write_empty

        # Pose
        self.pose_batch_size = pose_batch_size
        self.pose_postfix = pose_postfix
        self.cache_pose = cache_pose
        self.cache_frontal = cache_frontal
        self.smooth_poses = smooth_poses

        # Landmarks
        self.smooth_landmarks = smooth_landmarks
        self.landmarks_postfix = landmarks_postfix
        self.cache_landmarks = cache_landmarks
        self.lms_batch_size = lms_batch_size

        # Segmentation
        self.smooth_segmentation = smooth_segmentation
        self.segmentation_postfix = segmentation_postfix
        self.cache_segmentation = cache_segmentation
        self.seg_batch_size = seg_batch_size
        # Mouth removal needs cached landmarks, so it is silently disabled
        # when cache_landmarks is False.
        self.seg_remove_mouth = seg_remove_mouth and cache_landmarks

        # Initialize device
        # Gradients are disabled globally before any model is loaded:
        # this pipeline only runs inference.
        torch.set_grad_enabled(False)
        self.device, self.gpus = set_device(gpus, not cpu_only)

        # Load models
        # Each model is loaded only if its corresponding cache flag is on.
        self.face_pose = load_model(pose_model, 'face pose',
                                    self.device) if cache_pose else None
        self.L = load_model(lms_model, 'face landmarks',
                            self.device) if cache_landmarks else None
        self.S = load_model(seg_model, 'face segmentation',
                            self.device) if cache_segmentation else None

        # Initialize heatmap encoder
        self.heatmap_encoder = LandmarksHeatMapEncoder().to(self.device)

        # Initialize normalization tensors
        # Note: this is necessary because of the landmarks model
        # img_* maps [0, 1] images to [-1, 1]; context_* uses the standard
        # ImageNet mean/std. Shapes are (1, 3, 1, 1) for broadcasting over
        # NCHW batches.
        self.img_mean = torch.as_tensor([0.5, 0.5, 0.5],
                                        device=self.device).view(1, 3, 1, 1)
        self.img_std = torch.as_tensor([0.5, 0.5, 0.5],
                                       device=self.device).view(1, 3, 1, 1)
        self.context_mean = torch.as_tensor([0.485, 0.456, 0.406],
                                            device=self.device).view(
                                                1, 3, 1, 1)
        self.context_std = torch.as_tensor([0.229, 0.224, 0.225],
                                           device=self.device).view(
                                               1, 3, 1, 1)

        # Support multiple GPUs
        # Wrapping must happen after the models are loaded; None (disabled)
        # models stay None.
        if self.gpus and len(self.gpus) > 1:
            self.face_pose = nn.DataParallel(
                self.face_pose,
                self.gpus) if self.face_pose is not None else None
            self.L = nn.DataParallel(self.L,
                                     self.gpus) if self.L is not None else None
            self.S = nn.DataParallel(self.S,
                                     self.gpus) if self.S is not None else None

        # Initialize temportal smoothing
        # Temporal smoothing of segmentation masks; disabled (None) when
        # the smoothing kernel radius is 0 or negative.
        if smooth_segmentation > 0:
            self.smooth_seg = TemporalSmoothing(3, smooth_segmentation).to(
                self.device)
        else:
            self.smooth_seg = None

        # Initialize output videos format
        # 'avc1' selects H.264 encoding for OpenCV video output.
        self.fourcc = cv2.VideoWriter_fourcc(*'avc1')