Example 1
    def __getitem__(self, index):
        """ Sample the dataset at the given index.

        Args:
            index (int): Frame index

        Returns:
            tuple of np.array: The current video frame followed by any additional metadata, in the order
                specified by the postfixes.
        """
        if self.cap is None:
            # Lazily open the video file so each data loader worker holds its own capture handle
            self.cap = cv2.VideoCapture(self.vid_path)

        ret, frame_bgr = self.cap.read()
        assert ret and frame_bgr is not None, 'Failed to read frame from video at index: %d' % index
        frame_rgb = frame_bgr[:, :, ::-1]  # Convert BGR (OpenCV's default channel order) to RGB

        # Append additional data, decoding any binary-encoded segmentation masks on the fly
        data = [frame_rgb]
        for d in self.data:
            if isinstance(d[index], bytes):
                data.append(decode_binary_mask(d[index]))
            else:
                data.append(d[index])

        # Apply transformation
        if self.transform is not None:
            data = self.transform(data)

        return tuple(data) if len(data) > 1 else data[0]
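
All three examples rely on a decode_binary_mask helper whose definition is not shown. A minimal sketch of what it might look like, assuming each mask is stored as compressed (e.g. PNG-encoded) bytes; the encoding is an assumption, not given in the source:

    import cv2
    import numpy as np

    def decode_binary_mask(encoded_bytes):
        # Hypothetical implementation: the source only shows that the input is a bytes object
        buf = np.frombuffer(encoded_bytes, dtype=np.uint8)
        mask = cv2.imdecode(buf, cv2.IMREAD_GRAYSCALE)  # Decode the compressed image
        return mask > 0  # Return a boolean mask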
Example 2
    def __getitem__(self, index):
        """
        Args:
            index (int): Frame index

        Returns:
            tuple: (src_frames, src_landmarks, src_poses, bw, tgt_frame, tgt_landmarks, tgt_pose, tgt_seg)
        """
        if self.src_vid is None:
            # Open the source video in the data loader's worker process
            self.src_vid = cv2.VideoCapture(self.src_vid_seq_path)
        if self.tgt_vid is None:
            # Open the target video in the data loader's worker process
            self.tgt_vid = cv2.VideoCapture(self.tgt_vid_seq_path)

        # Read the next target frame and its metadata
        ret, tgt_frame_bgr = self.tgt_vid.read()
        assert ret and tgt_frame_bgr is not None, 'Failed to read frame from video at index: %d' % index
        tgt_frame = tgt_frame_bgr[:, :, ::-1]  # BGR to RGB
        tgt_landmarks = self.tgt_landmarks[index]
        tgt_pose = self.tgt_poses[index]
        tgt_seg = decode_binary_mask(self.tgt_encoded_seg[index])

        # Query source frames and metadata given the current target pose
        query_point, tilt_angle = tgt_pose[:2], tgt_pose[2]
        tri_index = self.tri.find_simplex(query_point)
        tri_vertices = self.tri.simplices[tri_index]
        tri_vertices = np.minimum(tri_vertices, self.valid_size)  # Clamp indices of auxiliary edge points

        # Compute the barycentric weights of the query point within its enclosing triangle
        b = self.tri.transform[tri_index, :2].dot(query_point - self.tri.transform[tri_index, 2])
        bw = np.array([b[0], b[1], 1 - b.sum()], dtype='float32')
        bw[tri_vertices >= self.valid_size] = 0.    # Zero the weights of auxiliary edge points
        bw /= bw.sum()                              # Renormalize so the weights sum to one

        # Cache source frames
        for tv in np.sort(tri_vertices):
            if self.src_frames[tv] is None:
                self.src_vid.set(cv2.CAP_PROP_POS_FRAMES, self.filtered_indices[tv])
                ret, frame_bgr = self.src_vid.read()
                assert ret and frame_bgr is not None, 'Failed to read frame from source video at index: %d' % tv
                frame_rgb = frame_bgr[:, :, ::-1]  # BGR to RGB
                self.src_frames[tv] = frame_rgb

        # Get source data from appearance map
        src_frames = [self.src_frames[tv] for tv in tri_vertices]
        src_landmarks = self.src_landmarks[tri_vertices].astype('float32')
        src_poses = self.src_poses[tri_vertices].astype('float32')

        # Apply source transformation
        if self.src_transform is not None:
            src_data = [(src_frames[i], src_landmarks[i], (src_poses[i][2] - tilt_angle) * 99.)
                        for i in range(len(src_frames))]
            src_data = self.src_transform(src_data)
            src_landmarks = torch.stack([src_data[i][1] for i in range(len(src_data))])
            src_frames = [src_data[i][0] for i in range(len(src_data))]
            src_poses[:, 2] = tilt_angle  # Align the source tilt angles with the target's

        # Apply target transformation
        if self.tgt_transform is not None:
            tgt_frame = self.tgt_transform(tgt_frame)

        # Combine pyramids in source frames if they exist
        if isinstance(src_frames[0], (list, tuple)):
            src_frames = [torch.stack([src_frames[f][p] for f in range(len(src_frames))], dim=0)
                          for p in range(len(src_frames[0]))]

        return src_frames, src_landmarks, src_poses, bw, tgt_frame, tgt_landmarks, tgt_pose, tgt_seg
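
The weight computation above follows SciPy's documented recipe for Delaunay.transform: for simplex i, transform[i, :2] holds the inverse affine map T and transform[i, 2] the offset r, so b = T(x - r) gives the first two barycentric coordinates and the third is 1 - b.sum(). A self-contained sketch on toy points:

    import numpy as np
    from scipy.spatial import Delaunay

    points = np.array([[0., 0.], [1., 0.], [0., 1.], [1., 1.]])
    tri = Delaunay(points)

    query = np.array([0.25, 0.25])
    i = tri.find_simplex(query)                                # Index of the enclosing triangle
    b = tri.transform[i, :2].dot(query - tri.transform[i, 2])  # First two barycentric coordinates
    bw = np.array([b[0], b[1], 1. - b.sum()])                  # Third coordinate completes the sum to 1

    assert np.isclose(bw.sum(), 1.)
    print(tri.simplices[i], bw)  # Triangle vertices and their interpolation weights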
Example 3
    def query(self, vid_index, seq_index, frame_index):
        """
        Args:
            vid_index (int): Index of the original video
            seq_index (int): Index of the video sequence
            frame_index (int or list of int): Index of the frame (or frames) within the video sequence

        Returns:
            tuple: Tuple containing:
                - tuple of np.array: Sampled data corresponding to the specified postfixes
                - int, optional: The target corresponding to the original video if ``target_list`` was specified
        """
        target = self.targets[vid_index] if self.targets is not None else None
        all_seq_paths = self.file_paths[vid_index]
        seq_paths = all_seq_paths[seq_index]
        frame_index = [frame_index] if not isinstance(frame_index, (list, tuple)) else frame_index

        # For each sequence path
        data = []
        for seq_path in seq_paths:
            seq_queries = []
            if seq_path.endswith('.mp4'):
                # Open video
                vid = cv2.VideoCapture(seq_path)

                # For each frame index
                for fi in frame_index:
                    vid.set(cv2.CAP_PROP_POS_FRAMES, fi)

                    # Read the frames from the video
                    frame_list = []
                    for i in range(self.frame_window):
                        ret, frame_bgr = vid.read()
                        assert ret and frame_bgr is not None, 'Failed to read frame from video: "%s"' % seq_path
                        frame_rgb = frame_bgr[:, :, ::-1]  # BGR to RGB
                        frame_list.append(frame_rgb)
                    seq_queries.append(frame_list if self.frame_window > 1 else frame_list[0])
                vid.release()  # Release the capture handle once all frames have been read
            elif seq_path.endswith('_lms.npz'):
                landmarks = np.load(seq_path)['landmarks']
                for fi in frame_index:
                    landmarks_window = landmarks[fi:fi + self.frame_window]
                    seq_queries.append(landmarks_window if self.frame_window > 1 else landmarks_window[0])
            elif seq_path.endswith('_pose.npz'):
                poses = np.load(seq_path)['poses']
                for fi in frame_index:
                    poses_window = poses[fi:fi + self.frame_window]
                    seq_queries.append(poses_window if self.frame_window > 1 else poses_window[0])
            elif seq_path.endswith('_seg.pkl'):
                segmentations = np.load(seq_path, allow_pickle=True)
                for fi in frame_index:
                    segmentations_window = segmentations[fi:fi + self.frame_window]
                    segmentations_window = [decode_binary_mask(s) for s in segmentations_window]
                    seq_queries.append(segmentations_window if self.frame_window > 1 else segmentations_window[0])
            else:
                raise RuntimeError('Unknown file type: "%s"' % seq_path)
            data.append(seq_queries if len(frame_index) > 1 else seq_queries[0])

        # Apply transformation
        if self.transform is not None:
            data = self.transform(data)

        if target is None:
            return tuple(data)
        else:
            return tuple(data) + (target,)
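
Examples 2 and 3 both use OpenCV's seek-then-read pattern (set CAP_PROP_POS_FRAMES, then read()) for random access into a video. A minimal standalone sketch of that pattern, with read_frame as a hypothetical helper name:

    import cv2

    def read_frame(video_path, frame_index):
        # Hypothetical helper isolating the seek-then-read pattern used above
        vid = cv2.VideoCapture(video_path)
        try:
            vid.set(cv2.CAP_PROP_POS_FRAMES, frame_index)  # Seek to the requested frame
            ret, frame_bgr = vid.read()
            assert ret and frame_bgr is not None, 'Failed to read frame %d from "%s"' % (frame_index, video_path)
            return frame_bgr[:, :, ::-1]  # BGR to RGB
        finally:
            vid.release()  # Always release the capture handle

Frame-accurate seeking can be slow and is codec-dependent, which is presumably why Example 2 caches decoded frames in self.src_frames instead of re-seeking on every call.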