class LoadImageFromFile(object):
    """Load image from file.

    The path is read from ``results[f'{key}_path']``. The decoded image is
    stored under ``results[key]`` and its shape under
    ``results[f'{key}_ori_shape']``.

    Args:
        io_backend (str): io backend where images are stored.
            Default: 'disk'.
        key (str): Key in results used to find the corresponding path.
            Default: 'gt'.
        flag (str): Loading flag for images. Default: 'color'.
        channel_order (str): Order of channel, candidates are 'bgr' and
            'rgb'. Default: 'bgr'.
        save_original_img (bool): If True, maintain a copy of the image in
            `results` dict with name of `f'ori_{key}'`. Default: False.
        kwargs (dict): Args for file client.
    """

    def __init__(self,
                 io_backend='disk',
                 key='gt',
                 flag='color',
                 channel_order='bgr',
                 save_original_img=False,
                 **kwargs):
        self.io_backend = io_backend
        self.key = key
        self.flag = flag
        self.channel_order = channel_order
        self.save_original_img = save_original_img
        self.kwargs = kwargs
        # The file client is created lazily on the first call.
        self.file_client = None

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        client = self.file_client
        if client is None:
            client = FileClient(self.io_backend, **self.kwargs)
            self.file_client = client

        path_key = f'{self.key}_path'
        filepath = str(results[path_key])
        raw = client.get(filepath)
        image = mmcv.imfrombytes(
            raw, flag=self.flag, channel_order=self.channel_order)  # HWC

        results[self.key] = image
        # The path is written back in its stringified form.
        results[path_key] = filepath
        results[f'{self.key}_ori_shape'] = image.shape
        if self.save_original_img:
            results[f'ori_{self.key}'] = image.copy()
        return results

    def __repr__(self):
        fields = (f'io_backend={self.io_backend}, key={self.key}, '
                  f'flag={self.flag}, '
                  f'save_original_img={self.save_original_img}')
        return f'{self.__class__.__name__}({fields})'
class DecordInit(object):
    """Using decord to initialize the video_reader.

    Decord: https://github.com/dmlc/decord

    Required keys are "filename", added or modified keys are "video_reader"
    and "total_frames".

    Args:
        io_backend (str): io backend used to fetch the video bytes.
            Default: 'disk'.
        num_threads (int): Value forwarded to ``decord.VideoReader`` as
            ``num_threads``. Default: 1.
        kwargs (dict): Args for file client.
    """

    def __init__(self, io_backend='disk', num_threads=1, **kwargs):
        self.io_backend = io_backend
        self.num_threads = num_threads
        self.kwargs = kwargs
        # The file client is created lazily on the first call.
        self.file_client = None

    def __call__(self, results):
        """Open the video named by ``results['filename']`` with Decord.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.

        Returns:
            dict: The same dict with "video_reader" and "total_frames" set.

        Raises:
            ImportError: If decord is not installed.
        """
        # decord is an optional dependency, so it is imported on demand.
        try:
            import decord
        except ImportError:
            raise ImportError(
                'Please run "pip install decord" to install Decord first.')

        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        video_bytes = self.file_client.get(results['filename'])
        reader = decord.VideoReader(
            io.BytesIO(video_bytes), num_threads=self.num_threads)

        results['video_reader'] = reader
        results['total_frames'] = len(reader)
        return results
class LoadImageFromFileList(LoadImageFromFile):
    """Load image from file list.

    It accepts a list of paths and reads one frame from each path. A list
    of frames is returned.

    Args:
        io_backend (str): io backend where images are stored.
            Default: 'disk'.
        key (str): Key in results used to find the corresponding paths.
            Default: 'gt'.
        flag (str): Loading flag for images. Default: 'color'.
        channel_order (str): Order of channel, candidates are 'bgr' and
            'rgb'. Default: 'bgr'.
        save_original_img (bool): If True, maintain a copy of the images in
            `results` dict with name of `f'ori_{key}'`. Default: False.
        kwargs (dict): Args for file client.
    """

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.

        Raises:
            TypeError: If ``results[f'{key}_path']`` is not a list.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        path_key = f'{self.key}_path'
        raw_paths = results[path_key]
        if not isinstance(raw_paths, list):
            raise TypeError(
                f'filepath should be list, but got {type(raw_paths)}')
        filepaths = list(map(str, raw_paths))

        keep_originals = self.save_original_img
        frames = []
        originals = []
        for path in filepaths:
            frame = mmcv.imfrombytes(
                self.file_client.get(path),
                flag=self.flag,
                channel_order=self.channel_order)  # HWC
            # Give two-dimensional images an explicit trailing axis.
            if frame.ndim == 2:
                frame = np.expand_dims(frame, axis=2)
            frames.append(frame)
            if keep_originals:
                originals.append(frame.copy())

        results[self.key] = frames
        results[path_key] = filepaths
        results[f'{self.key}_ori_shape'] = [frame.shape for frame in frames]
        if keep_originals:
            results[f'ori_{self.key}'] = originals
        return results
class AudioDecodeInit(object):
    """Using librosa to initialize the audio reader.

    Required keys are "audio_path", added or modified keys are "length",
    "sample_rate", "audios".

    Args:
        io_backend (str): io backend where frames are stored.
            Default: 'disk'.
        sample_rate (int): Audio sampling times per second. Default: 16000.
        pad_method (str): How the 10 second dummy signal is generated when
            ``results['audio_path']`` does not exist. Either 'zero' or
            'random'. Default: 'zero'.
        kwargs (dict): Args for file client.

    Raises:
        NotImplementedError: If ``pad_method`` is neither 'random' nor
            'zero'.
    """

    def __init__(self,
                 io_backend='disk',
                 sample_rate=16000,
                 pad_method='zero',
                 **kwargs):
        if pad_method not in ('random', 'zero'):
            raise NotImplementedError(
                f"pad_method must be 'random' or 'zero', "
                f'but got {pad_method!r}')
        self.io_backend = io_backend
        self.sample_rate = sample_rate
        self.pad_method = pad_method
        self.kwargs = kwargs
        # The file client is created lazily on the first call.
        self.file_client = None

    def _zero_pad(self, shape):
        """Return a float32 array of zeros with the given shape."""
        return np.zeros(shape, dtype=np.float32)

    def _random_pad(self, shape):
        """Return a float32 array of uniform noise in [-1, 1].

        ``shape`` may be an int or a tuple, like for ``_zero_pad``.
        """
        # librosa load raw audio file into a distribution of -1~+1
        noise = np.random.random_sample(shape).astype(np.float32)
        return noise * 2 - 1

    def __call__(self, results):
        """Perform the librosa initialization.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.

        Returns:
            dict: The same dict with "length", "sample_rate" and "audios"
                set.

        Raises:
            ImportError: If librosa is not installed.
        """
        # librosa is an optional dependency, so it is imported on demand.
        try:
            import librosa
        except ImportError:
            raise ImportError('Please install librosa first.')

        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        # NOTE(review): existence is checked on the local filesystem even
        # when io_backend is not 'disk' - confirm this is intended.
        if osp.exists(results['audio_path']):
            file_obj = io.BytesIO(self.file_client.get(results['audio_path']))
            y, sr = librosa.load(file_obj, sr=self.sample_rate)
        else:
            # Generate a dummy 10s input with the configured pad method
            pad_func = getattr(self, f'_{self.pad_method}_pad')
            y = pad_func(int(round(10.0 * self.sample_rate)))
            sr = self.sample_rate

        results['length'] = y.shape[0]
        results['sample_rate'] = sr
        results['audios'] = y
        return results
def load_fileclient_dist(filename, backend, map_location):
    """Load a checkpoint through a file client in a distributed setting.

    Local rank 0 fetches and loads the checkpoint first. When more than one
    process is running, all processes then synchronize on a barrier, after
    which the remaining ranks fetch and load the checkpoint themselves.

    Args:
        filename (str): Location of the checkpoint, passed to the file
            client.
        backend (str): File client backend. Only 'ceph' is accepted.
        map_location: Forwarded to ``torch.load``.

    Returns:
        The object returned by ``torch.load``.

    Raises:
        ValueError: If ``backend`` is not an allowed backend.
    """
    rank, world_size = get_dist_info()
    # Prefer the node-local rank when the launcher exports it.
    rank = int(os.environ.get('LOCAL_RANK', rank))

    if backend not in ('ceph', ):
        raise ValueError(f'Load from Backend {backend} is not supported.')

    def _read_checkpoint():
        client = FileClient(backend=backend)
        stream = io.BytesIO(client.get(filename))
        return torch.load(stream, map_location=map_location)

    if rank == 0:
        checkpoint = _read_checkpoint()
    if world_size > 1:
        torch.distributed.barrier()
    if rank > 0:
        checkpoint = _read_checkpoint()
    return checkpoint
class RandomLoadResizeBg:
    """Randomly load a background image and resize it.

    Required key is "fg", added key is "bg".

    Args:
        bg_dir (str): Path of directory to load background images from.
        io_backend (str): io backend where images are stored.
            Default: 'disk'.
        flag (str): Loading flag for images. Default: 'color'.
        channel_order (str): Order of channel, candidates are 'bgr' and
            'rgb'. Default: 'bgr'.
        kwargs (dict): Args for file client.
    """

    def __init__(self,
                 bg_dir,
                 io_backend='disk',
                 flag='color',
                 channel_order='bgr',
                 **kwargs):
        self.bg_dir = bg_dir
        # The directory is scanned once, at construction time.
        self.bg_list = list(mmcv.scandir(bg_dir))
        self.io_backend = io_backend
        self.flag = flag
        self.channel_order = channel_order
        self.kwargs = kwargs
        self.file_client = None

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        # The background is resized to the foreground's height and width.
        height, width = results['fg'].shape[:2]
        choice = np.random.randint(len(self.bg_list))
        bg_path = Path(self.bg_dir) / self.bg_list[choice]
        raw = self.file_client.get(bg_path)
        image = mmcv.imfrombytes(
            raw, flag=self.flag, channel_order=self.channel_order)  # HWC
        results['bg'] = mmcv.imresize(
            image, (width, height), interpolation='bicubic')
        return results

    def __repr__(self):
        return f"{self.__class__.__name__}(bg_dir='{self.bg_dir}')"
class PyAVInit(object):
    """Using pyav to initialize the video.

    PyAV: https://github.com/mikeboers/PyAV

    Required keys are "filename", added or modified keys are "video_reader",
    and "total_frames".

    Args:
        io_backend (str): io backend where frames are stored.
            Default: 'disk'.
        kwargs (dict): Args for file client.
    """

    def __init__(self, io_backend='disk', **kwargs):
        self.io_backend = io_backend
        self.kwargs = kwargs
        # The file client is created lazily on the first call.
        self.file_client = None

    def __call__(self, results):
        """Perform the PyAV initiation.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.

        Returns:
            dict: The same dict with "video_reader" set, and "total_frames"
                set when the frame count could be read.

        Raises:
            ImportError: If PyAV is not installed.
        """
        # PyAV is an optional dependency, so it is imported on demand.
        try:
            import av
        except ImportError:
            raise ImportError('Please run "conda install av -c conda-forge" '
                              'or "pip install av" to install PyAV first.')

        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        video_bytes = self.file_client.get(results['filename'])
        container = av.open(io.BytesIO(video_bytes))
        results['video_reader'] = container

        # Best effort: if the frame count cannot be looked up, print the
        # offending file name and leave "total_frames" unset.
        try:
            results['total_frames'] = container.streams.video[0].frames
        except (KeyError, IndexError):
            print(results['filename'])
        return results
class OpenCVInit(object):
    """Using OpenCV to initialize the video_reader.

    Required keys are "filename", added or modified keys are "new_path",
    "video_reader" and "total_frames".

    A temporary folder is created at construction time and removed when the
    object is garbage collected. For non-'disk' backends the video bytes
    are first written into that folder so ``mmcv.VideoReader`` can open
    them by path.

    Args:
        io_backend (str): io backend where videos are stored.
            Default: 'disk'.
        kwargs (dict): Args for file client.
    """

    def __init__(self, io_backend='disk', **kwargs):
        self.io_backend = io_backend
        self.kwargs = kwargs
        self.file_client = None
        random_string = get_random_string()
        thread_id = get_thread_id()
        self.tmp_folder = osp.join(get_shm_dir(),
                                   f'{random_string}_{thread_id}')
        os.mkdir(self.tmp_folder)

    def __call__(self, results):
        """Perform the OpenCV initiation.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.

        Returns:
            dict: The same dict with "new_path", "video_reader" and
                "total_frames" set.
        """
        if self.io_backend == 'disk':
            new_path = results['filename']
        else:
            if self.file_client is None:
                self.file_client = FileClient(self.io_backend, **self.kwargs)

            thread_id = get_thread_id()
            # save the file of same thread at the same place
            new_path = osp.join(self.tmp_folder, f'tmp_{thread_id}.mp4')
            with open(new_path, 'wb') as f:
                f.write(self.file_client.get(results['filename']))

        container = mmcv.VideoReader(new_path)
        results['new_path'] = new_path
        results['video_reader'] = container
        results['total_frames'] = len(container)
        return results

    def __del__(self):
        # __del__ also runs when __init__ failed before tmp_folder was set,
        # and the folder may already be gone; neither case should raise
        # from a finalizer.
        tmp_folder = getattr(self, 'tmp_folder', None)
        if tmp_folder is not None:
            shutil.rmtree(tmp_folder, ignore_errors=True)
class CompositeFg:
    """Composite foreground with a random foreground.

    This class composites the current training sample with additional data
    randomly (could be from the same dataset). With probability 0.5, the
    sample will be composited with a random sample from the specified
    directory. The composition is performed as:

    .. math::
        fg_{new} = \\alpha_1 * fg_1 + (1 - \\alpha_1) * fg_2

        \\alpha_{new} = 1 - (1 - \\alpha_1) * (1 - \\alpha_2)

    where :math:`(fg_1, \\alpha_1)` is from the current sample and
    :math:`(fg_2, \\alpha_2)` is the randomly loaded sample. With the above
    composition, :math:`\\alpha_{new}` is still in `[0, 1]`.

    Required keys are "alpha" and "fg". Modified keys are "alpha" and "fg".

    Args:
        fg_dirs (str | list[str]): Path of directories to load foreground
            images from.
        alpha_dirs (str | list[str]): Path of directories to load alpha
            mattes from.
        interpolation (str): Interpolation method of `mmcv.imresize` to
            resize the randomly loaded images. Default: 'nearest'.
        io_backend (str): io backend where images are stored.
            Default: 'disk'.
        kwargs (dict): Args for file client.
    """

    def __init__(self,
                 fg_dirs,
                 alpha_dirs,
                 interpolation='nearest',
                 io_backend='disk',
                 **kwargs):
        self.fg_dirs = fg_dirs if isinstance(fg_dirs, list) else [fg_dirs]
        self.alpha_dirs = alpha_dirs if isinstance(
            alpha_dirs, list) else [alpha_dirs]
        self.interpolation = interpolation

        self.fg_list, self.alpha_list = self._get_file_list(
            self.fg_dirs, self.alpha_dirs)

        self.io_backend = io_backend
        self.file_client = None
        self.kwargs = kwargs

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        fg = results['fg']
        alpha = results['alpha'].astype(np.float32) / 255.
        h, w = fg.shape[:2]

        # randomly select fg
        if np.random.rand() < 0.5:
            idx = np.random.randint(len(self.fg_list))
            fg2_bytes = self.file_client.get(self.fg_list[idx])
            fg2 = mmcv.imfrombytes(fg2_bytes)
            alpha2_bytes = self.file_client.get(self.alpha_list[idx])
            alpha2 = mmcv.imfrombytes(alpha2_bytes, flag='grayscale')
            alpha2 = alpha2.astype(np.float32) / 255.

            fg2 = mmcv.imresize(fg2, (w, h), interpolation=self.interpolation)
            alpha2 = mmcv.imresize(
                alpha2, (w, h), interpolation=self.interpolation)

            fg, alpha = self._composite(fg, alpha, fg2, alpha2)

        results['fg'] = fg
        results['alpha'] = (alpha * 255).astype(np.uint8)
        return results

    @staticmethod
    def _composite(fg, alpha, fg2, alpha2):
        """Blend ``(fg, alpha)`` with ``(fg2, alpha2)``.

        Args:
            fg (np.ndarray): Current foreground image.
            alpha (np.ndarray): Current alpha in ``[0, 1]``; one dimension
                fewer than ``fg``.
            fg2 (np.ndarray): Second foreground, same size as ``fg``.
            alpha2 (np.ndarray): Second alpha in ``[0, 1]``, same size as
                ``alpha``.

        Returns:
            tuple: ``(fg, alpha)``. The inputs are returned untouched when
                the merged alpha would be all-one; otherwise the blended
                foreground as uint8 and the merged alpha.
        """
        # the overlap of two 50% transparency will be 75%
        alpha_new = 1 - (1 - alpha) * (1 - alpha2)
        # if the result alpha is all-one, then we avoid composition
        if not np.any(alpha_new < 1):
            return fg, alpha

        weight = alpha[..., None]
        blended = (fg.astype(np.float32) * weight +
                   fg2.astype(np.float32) * (1 - weight))
        # The cast result must be kept; ``astype`` does not work in place.
        return blended.astype(np.uint8), alpha_new

    @staticmethod
    def _get_file_list(fg_dirs, alpha_dirs):
        """Collect matching foreground and alpha file paths.

        Directories are paired positionally. Inside each pair the sorted
        file listings are paired positionally as well.

        Returns:
            tuple[list, list]: Foreground paths and alpha paths.
        """
        all_fg_list = list()
        all_alpha_list = list()
        for fg_dir, alpha_dir in zip(fg_dirs, alpha_dirs):
            fg_list = sorted(mmcv.scandir(fg_dir))
            alpha_list = sorted(mmcv.scandir(alpha_dir))
            # we assume the file names for fg and alpha are the same
            assert len(fg_list) == len(alpha_list), (
                f'{fg_dir} and {alpha_dir} should have the same number of '
                f'images ({len(fg_list)} differs from {len(alpha_list)})')
            fg_list = [osp.join(fg_dir, fg) for fg in fg_list]
            alpha_list = [osp.join(alpha_dir, alpha) for alpha in alpha_list]

            all_fg_list.extend(fg_list)
            all_alpha_list.extend(alpha_list)
        return all_fg_list, all_alpha_list

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(fg_dirs={self.fg_dirs}, alpha_dirs={self.alpha_dirs}, '
                     f"interpolation='{self.interpolation}')")
        return repr_str
class LoadPairedImageFromFile(LoadImageFromFile):
    """Load a pair of images from file.

    Each sample contains a pair of images, which are concatenated in the w
    dimension (a|b). This is a special loading class for generation paired
    dataset. It loads a pair of images as the common loader does and crops
    it into two images with the same shape in different domains.

    Required key is "pair_path". Added or modified keys are "pair",
    "pair_ori_shape", "ori_pair", "img_a", "img_b", "img_a_path",
    "img_b_path", "img_a_ori_shape", "img_b_ori_shape", "ori_img_a" and
    "ori_img_b".

    Args:
        io_backend (str): io backend where images are stored.
            Default: 'disk'.
        key (str): Key in results used to find the corresponding path.
            Default: 'gt'.
        flag (str): Loading flag for images. Default: 'color'.
        channel_order (str): Order of channel, candidates are 'bgr' and
            'rgb'. Default: 'bgr'.
        save_original_img (bool): If True, maintain a copy of the image in
            `results` dict with name of `f'ori_{key}'`. Default: False.
        kwargs (dict): Args for file client.
    """

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.

        Raises:
            ValueError: If the width of the loaded image is odd.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        filepath = str(results[f'{self.key}_path'])
        img_bytes = self.file_client.get(filepath)
        # Honour the configured channel order, as the parent loader does;
        # the default 'bgr' keeps the previous output.
        img = mmcv.imfrombytes(
            img_bytes, flag=self.flag,
            channel_order=self.channel_order)  # HWC
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)

        results[self.key] = img
        results[f'{self.key}_path'] = filepath
        results[f'{self.key}_ori_shape'] = img.shape
        if self.save_original_img:
            results[f'ori_{self.key}'] = img.copy()

        # crop pair into a and b
        w = img.shape[1]
        if w % 2 != 0:
            raise ValueError(
                f'The width of image pair must be even number, but got {w}.')
        new_w = w // 2
        img_a = img[:, :new_w, :]
        img_b = img[:, new_w:, :]

        results['img_a'] = img_a
        results['img_b'] = img_b
        # Both halves come from the same file.
        results['img_a_path'] = filepath
        results['img_b_path'] = filepath
        results['img_a_ori_shape'] = img_a.shape
        results['img_b_ori_shape'] = img_b.shape
        if self.save_original_img:
            results['ori_img_a'] = img_a.copy()
            results['ori_img_b'] = img_b.copy()

        return results
class LoadMask(object):
    """Load Mask for multiple types.

    For different types of mask, users need to provide the corresponding
    config dict.

    Example config for bbox:

    .. code-block:: python

        config = dict(img_shape=(256, 256), max_bbox_shape=128)

    Example config for irregular:

    .. code-block:: python

        config = dict(
            img_shape=(256, 256),
            num_vertexes=(4, 12),
            max_angle=4.,
            length_range=(10, 100),
            brush_width=(10, 40),
            area_ratio_range=(0.15, 0.5))

    Example config for ff:

    .. code-block:: python

        config = dict(
            img_shape=(256, 256),
            num_vertexes=(4, 12),
            mean_angle=1.2,
            angle_range=0.4,
            brush_width=(12, 40))

    Example config for set:

    .. code-block:: python

        config = dict(
            mask_list_file='xxx/xxx/ooxx.txt',
            prefix='/xxx/xxx/ooxx/',
            io_backend='disk',
            flag='unchanged',
            file_client_kwargs=dict()
        )

        The mask_list_file contains the list of mask file name like this:
            test1.jpeg
            test2.jpeg
            ...
            ...

        The prefix gives the data path. Only the first space-separated
        token of each line is used, and blank lines are ignored.

    Args:
        mask_mode (str): Mask mode in ['bbox', 'irregular', 'ff', 'set',
            'file'].
            * bbox: square bounding box masks.
            * irregular: irregular holes.
            * ff: free-form holes from DeepFillv2.
            * set: randomly get a mask from a mask set.
            * file: get mask from 'mask_path' in results.
        mask_config (dict): Params for creating masks. Each type of mask
            needs different configs.
    """

    def __init__(self, mask_mode='bbox', mask_config=None):
        self.mask_mode = mask_mode
        self.mask_config = dict() if mask_config is None else mask_config
        assert isinstance(self.mask_config, dict)

        # set init info if needed in some modes
        self._init_info()

    def _init_info(self):
        """Prepare the file-reading state needed by 'set' and 'file'."""
        if self.mask_mode == 'set':
            # get mask list information
            self.mask_list = []
            mask_list_file = self.mask_config['mask_list_file']
            with open(mask_list_file, 'r') as f:
                for line in f:
                    mask_name = line.strip().split(' ')[0]
                    # A blank line would otherwise register the bare prefix
                    # directory as a mask file.
                    if not mask_name:
                        continue
                    self.mask_list.append(
                        Path(self.mask_config['prefix']).joinpath(mask_name))
            self.mask_set_size = len(self.mask_list)
            self.io_backend = self.mask_config['io_backend']
            self.flag = self.mask_config['flag']
            self.file_client_kwargs = self.mask_config['file_client_kwargs']
            self.file_client = None
        elif self.mask_mode == 'file':
            self.io_backend = 'disk'
            self.flag = 'unchanged'
            self.file_client_kwargs = dict()
            self.file_client = None

    def _load_mask(self, path):
        """Read the mask at ``path`` and binarize it.

        Returns:
            np.ndarray: Mask with a single trailing channel in which every
                positive value has been set to 1.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend,
                                          **self.file_client_kwargs)
        mask_bytes = self.file_client.get(path)
        mask = mmcv.imfrombytes(mask_bytes, flag=self.flag)  # HWC, BGR
        if mask.ndim == 2:
            mask = np.expand_dims(mask, axis=2)
        else:
            # keep only the first channel
            mask = mask[:, :, 0:1]

        mask[mask > 0] = 1.
        return mask

    def _get_random_mask_from_set(self):
        """Load a uniformly chosen mask from the configured mask list."""
        # The upper bound of randint is exclusive, so the index is valid.
        mask_idx = np.random.randint(0, self.mask_set_size)
        return self._load_mask(self.mask_list[mask_idx])

    def _get_mask_from_file(self, path):
        """Load the mask stored at ``path``."""
        return self._load_mask(path)

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.

        Raises:
            NotImplementedError: If ``mask_mode`` is not supported.
        """
        if self.mask_mode == 'bbox':
            mask_bbox = random_bbox(**self.mask_config)
            mask = bbox2mask(self.mask_config['img_shape'], mask_bbox)
            results['mask_bbox'] = mask_bbox
        elif self.mask_mode == 'irregular':
            mask = get_irregular_mask(**self.mask_config)
        elif self.mask_mode == 'set':
            mask = self._get_random_mask_from_set()
        elif self.mask_mode == 'ff':
            mask = brush_stroke_mask(**self.mask_config)
        elif self.mask_mode == 'file':
            mask = self._get_mask_from_file(results['mask_path'])
        else:
            raise NotImplementedError(
                f'Mask mode {self.mask_mode} has not been implemented.')
        results['mask'] = mask
        return results

    def __repr__(self):
        return self.__class__.__name__ + f"(mask_mode='{self.mask_mode}')"
class LoadImageFromFile:
    """Load image from file.

    Args:
        io_backend (str): io backend where images are stored.
            Default: 'disk'.
        key (str): Key in results used to find the corresponding path.
            Default: 'gt'.
        flag (str): Loading flag for images. Default: 'color'.
        channel_order (str): Order of channel, candidates are 'bgr' and
            'rgb'. Default: 'bgr'.
        convert_to (str | None): The color space of the output image. Only
            'y' (case-insensitive) is supported. If None, no conversion is
            conducted. Default: None.
        save_original_img (bool): If True, maintain a copy of the image in
            `results` dict with name of `f'ori_{key}'`. Default: False.
        use_cache (bool): If True, keep every decoded image in memory,
            keyed by its path, and reuse it on later calls.
            Default: False.
        backend (str): The image loading backend type. Options are `cv2`,
            `pillow`, and 'turbojpeg'. Default: None.
        kwargs (dict): Args for file client.
    """

    def __init__(self,
                 io_backend='disk',
                 key='gt',
                 flag='color',
                 channel_order='bgr',
                 convert_to=None,
                 save_original_img=False,
                 use_cache=False,
                 backend=None,
                 **kwargs):
        self.io_backend = io_backend
        self.key = key
        self.flag = flag
        self.save_original_img = save_original_img
        self.channel_order = channel_order
        self.convert_to = convert_to
        self.kwargs = kwargs
        self.file_client = None
        self.use_cache = use_cache
        self.cache = None
        self.backend = backend

    def _read_image(self, filepath):
        """Fetch ``filepath`` through the file client and decode it."""
        img_bytes = self.file_client.get(filepath)
        return mmcv.imfrombytes(
            img_bytes,
            flag=self.flag,
            channel_order=self.channel_order,
            backend=self.backend)  # HWC

    def _convert(self, img):
        """Convert ``img`` to the Y channel according to ``convert_to``.

        Raises:
            ValueError: If the target is not 'y' or the channel order is
                neither 'bgr' nor 'rgb'.
        """
        unsupported = ValueError('Currently support only "bgr2ycbcr" or '
                                 '"rgb2ycbcr".')
        # Validate the target first, so an unsupported target is rejected
        # for 'rgb' input just as it is for 'bgr' input.
        if self.convert_to.lower() != 'y':
            raise unsupported
        if self.channel_order == 'bgr':
            return mmcv.bgr2ycbcr(img, y_only=True)
        if self.channel_order == 'rgb':
            return mmcv.rgb2ycbcr(img, y_only=True)
        raise unsupported

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.

        Raises:
            ValueError: If ``convert_to`` is set to an unsupported target.
        """
        filepath = str(results[f'{self.key}_path'])
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        if self.use_cache:
            if self.cache is None:
                self.cache = dict()
            if filepath in self.cache:
                img = self.cache[filepath]
            else:
                img = self._read_image(filepath)
                # The decoded image is cached before any conversion.
                self.cache[filepath] = img
        else:
            img = self._read_image(filepath)

        if self.convert_to is not None:
            img = self._convert(img)
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)

        results[self.key] = img
        results[f'{self.key}_path'] = filepath
        results[f'{self.key}_ori_shape'] = img.shape
        if self.save_original_img:
            results[f'ori_{self.key}'] = img.copy()

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (
            f'(io_backend={self.io_backend}, key={self.key}, '
            f'flag={self.flag}, save_original_img={self.save_original_img}, '
            f'channel_order={self.channel_order}, use_cache={self.use_cache})')
        return repr_str
class FrameSelector(object):
    """Select raw frames with given indices.

    Required keys are "frame_dir", "filename_tmpl", "modality" and
    "frame_inds", added or modified keys are "imgs", "img_shape" and
    "original_shape". The optional key "offset" (default 0) is added to
    every frame index before the file name is formatted.

    Args:
        io_backend (str): IO backend where frames are stored.
            Default: 'disk'.
        decoding_backend (str): Backend used for image decoding.
            Default: 'cv2'.
        kwargs (dict, optional): Arguments for FileClient.
    """

    def __init__(self, io_backend='disk', decoding_backend='cv2', **kwargs):
        self.io_backend = io_backend
        self.decoding_backend = decoding_backend
        self.kwargs = kwargs
        # The file client is created lazily on the first call.
        self.file_client = None

    def __call__(self, results):
        """Perform the FrameSelector selecting given indices.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.

        Returns:
            dict: The same dict with "imgs", "original_shape" and
                "img_shape" set.

        Raises:
            NotImplementedError: If ``results['modality']`` is neither
                'RGB' nor 'Flow'.
        """
        mmcv.use_backend(self.decoding_backend)

        directory = results['frame_dir']
        filename_tmpl = results['filename_tmpl']
        modality = results['modality']

        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        imgs = list()

        if results['frame_inds'].ndim != 1:
            # squeeze() turns a single-index array such as shape (1, 1)
            # into a 0-d array that cannot be iterated; keep it 1-d.
            results['frame_inds'] = np.atleast_1d(
                np.squeeze(results['frame_inds']))

        offset = results.get('offset', 0)

        for frame_idx in results['frame_inds']:
            frame_idx = frame_idx + offset
            if modality == 'RGB':
                filepath = osp.join(directory, filename_tmpl.format(frame_idx))
                img_bytes = self.file_client.get(filepath)
                # Get frame with channel order RGB directly.
                cur_frame = mmcv.imfrombytes(img_bytes, channel_order='rgb')
                imgs.append(cur_frame)
            elif modality == 'Flow':
                # Flow frames are stored as separate x and y images.
                x_filepath = osp.join(directory,
                                      filename_tmpl.format('x', frame_idx))
                y_filepath = osp.join(directory,
                                      filename_tmpl.format('y', frame_idx))
                x_img_bytes = self.file_client.get(x_filepath)
                x_frame = mmcv.imfrombytes(x_img_bytes, flag='grayscale')
                y_img_bytes = self.file_client.get(y_filepath)
                y_frame = mmcv.imfrombytes(y_img_bytes, flag='grayscale')
                imgs.extend([x_frame, y_frame])
            else:
                raise NotImplementedError(
                    f'Modality {modality} is not supported.')

        results['imgs'] = imgs
        results['original_shape'] = imgs[0].shape[:2]
        results['img_shape'] = imgs[0].shape[:2]

        return results
class LoadKineticsPose:
    """Load Kinetics Pose given filename (The format should be pickle)

    Required keys are "filename", "total_frames", "img_shape", "frame_inds",
    "anno_inds" (for mmpose source, optional), added or modified keys are
    "keypoint", "keypoint_score", "num_person" and "total_frames". The keys
    "filename", "frame_inds", "anno_inds" and "box_score" are removed from
    the results.

    Args:
        io_backend (str): IO backend where frames are stored.
            Default: 'disk'.
        squeeze (bool): Whether to remove frames with no human pose.
            Default: True.
        max_person (int): The max number of persons in a frame.
            Default: 100.
        keypoint_weight (dict): The weight of keypoints. We set the
            confidence score of a person as the weighted sum of confidence
            scores of each joint. Persons with low confidence scores are
            dropped (if exceed max_person).
            Default: dict(face=1, torso=2, limb=3).
        source (str): The sources of the keypoints used. Choices are
            'mmpose' and 'openpose'. Default: 'mmpose'.
        kwargs (dict, optional): Arguments for FileClient.

    Raises:
        NotImplementedError: If ``source`` is neither 'openpose' nor
            'mmpose'.
    """

    def __init__(self,
                 io_backend='disk',
                 squeeze=True,
                 max_person=100,
                 keypoint_weight=dict(face=1, torso=2, limb=3),
                 source='mmpose',
                 **kwargs):
        self.io_backend = io_backend
        self.squeeze = squeeze
        self.max_person = max_person
        # Copy so that the shared default dict is never mutated.
        self.keypoint_weight = cp.deepcopy(keypoint_weight)
        self.source = source

        # Keypoint indices of each body part group for the chosen source.
        if source == 'openpose':
            self.kpsubset = dict(
                face=[0, 14, 15, 16, 17],
                torso=[1, 2, 8, 5, 11],
                limb=[3, 4, 6, 7, 9, 10, 12, 13])
        elif source == 'mmpose':
            self.kpsubset = dict(
                face=[0, 1, 2, 3, 4],
                torso=[5, 6, 11, 12],
                limb=[7, 8, 9, 10, 13, 14, 15, 16])
        else:
            raise NotImplementedError('Unknown source of Kinetics Pose')

        self.kwargs = kwargs
        self.file_client = None

    def __call__(self, results):
        """Load the keypoints of one video and arrange them per person.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.

        Returns:
            dict: The same dict with "keypoint" of shape
                ``(person, frame, keypoint, 2)``, "keypoint_score" of shape
                ``(person, frame, keypoint)``, "num_person" and
                "total_frames" set.
        """
        assert 'filename' in results
        filename = results.pop('filename')

        # only applicable to source == 'mmpose'
        anno_inds = None
        if 'anno_inds' in results:
            assert self.source == 'mmpose'
            anno_inds = results.pop('anno_inds')
        results.pop('box_score', None)

        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        raw_bytes = self.file_client.get(filename)

        # only the kp array is in the pickle file, each kp include x, y,
        # score.
        # NOTE: pickle.loads can execute arbitrary code; annotation files
        # must come from a trusted source.
        kps = pickle.loads(raw_bytes)

        total_frames = results['total_frames']

        frame_inds = results.pop('frame_inds')

        if anno_inds is not None:
            kps = kps[anno_inds]
            frame_inds = frame_inds[anno_inds]

        frame_inds = list(frame_inds)

        def mapinds(inds):
            # Renumber the frame indices as 0..k-1, preserving their order.
            uni = np.unique(inds)
            mapp = {x: i for i, x in enumerate(uni)}
            inds = [mapp[x] for x in inds]
            return np.array(inds, dtype=np.int16)

        if self.squeeze:
            frame_inds = mapinds(frame_inds)
            total_frames = np.max(frame_inds) + 1

        # write it back
        results['total_frames'] = total_frames

        h, w = results['img_shape']
        if self.source == 'openpose':
            kps[:, :, 0] *= w
            kps[:, :, 1] *= h

        num_kp = kps.shape[1]
        # The largest number of annotations sharing one frame index. This
        # equals the count reported by scipy.stats.mode, without depending
        # on the shape of its return value.
        num_person = np.unique(frame_inds, return_counts=True)[1].max()

        new_kp = np.zeros([num_person, total_frames, num_kp, 2],
                          dtype=np.float16)
        new_kpscore = np.zeros([num_person, total_frames, num_kp],
                               dtype=np.float16)
        # 32768 is enough
        num_person_frame = np.zeros([total_frames], dtype=np.int16)

        # Fill the next free person slot of the frame for each annotation.
        for frame_ind, kp in zip(frame_inds, kps):
            person_ind = num_person_frame[frame_ind]
            new_kp[person_ind, frame_ind] = kp[:, :2]
            new_kpscore[person_ind, frame_ind] = kp[:, 2]
            num_person_frame[frame_ind] += 1

        kpgrp = self.kpsubset
        weight = self.keypoint_weight
        results['num_person'] = num_person

        if num_person > self.max_person:
            # Sort the persons of every frame by descending weighted score
            # so that the slice below keeps the most confident ones.
            for i in range(total_frames):
                np_frame = num_person_frame[i]
                val = new_kpscore[:np_frame, i]

                val = (
                    np.sum(val[:, kpgrp['face']], 1) * weight['face'] +
                    np.sum(val[:, kpgrp['torso']], 1) * weight['torso'] +
                    np.sum(val[:, kpgrp['limb']], 1) * weight['limb'])
                inds = sorted(range(np_frame), key=lambda x: -val[x])
                new_kpscore[:np_frame, i] = new_kpscore[inds, i]
                new_kp[:np_frame, i] = new_kp[inds, i]
            results['num_person'] = self.max_person

        results['keypoint'] = new_kp[:self.max_person]
        results['keypoint_score'] = new_kpscore[:self.max_person]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'io_backend={self.io_backend}, '
                    f'squeeze={self.squeeze}, '
                    f'max_person={self.max_person}, '
                    f'keypoint_weight={self.keypoint_weight}, '
                    f'source={self.source}, '
                    f'kwargs={self.kwargs})')
        return repr_str