Пример #1
class LoadImageFromFile(object):
    """Load image from file.

    Args:
        io_backend (str): io backend where images are stored.
            Default: 'disk'.
        key (str): Keys in results to find corresponding path. Default: 'gt'.
        flag (str): Loading flag for images. Default: 'color'.
        channel_order (str): Order of channel, candidates are 'bgr' and 'rgb'.
            Default: 'bgr'.
        save_original_img (bool): If True, maintain a copy of the image in
            `results` dict with name of `f'ori_{key}'`. Default: False.
        kwargs (dict): Args for file client.
    """

    def __init__(self,
                 io_backend='disk',
                 key='gt',
                 flag='color',
                 channel_order='bgr',
                 save_original_img=False,
                 **kwargs):
        self.io_backend = io_backend
        self.key = key
        self.flag = flag
        self.save_original_img = save_original_img
        self.channel_order = channel_order
        self.kwargs = kwargs
        # The file client is created lazily on the first call.
        self.file_client = None

    def __call__(self, results):
        """Call function.

        Reads the path stored under ``f'{key}_path'``, decodes it and stores
        the image, the (stringified) path and the image shape in ``results``.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)
        filepath = str(results[f'{self.key}_path'])
        img_bytes = self.file_client.get(filepath)
        # Forward the configured loading flag as well; it was stored in
        # __init__ but never reached the decoder.
        img = mmcv.imfrombytes(
            img_bytes, flag=self.flag,
            channel_order=self.channel_order)  # HWC

        results[self.key] = img
        results[f'{self.key}_path'] = filepath
        results[f'{self.key}_ori_shape'] = img.shape
        if self.save_original_img:
            results[f'ori_{self.key}'] = img.copy()

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (
            f'(io_backend={self.io_backend}, key={self.key}, '
            f'flag={self.flag}, save_original_img={self.save_original_img})')
        return repr_str
Пример #2
class DecordInit(object):
    """Using decord to initialize the video_reader.

    Decord: https://github.com/dmlc/decord

    Required keys are "filename",
    added or modified keys are "video_reader" and "total_frames".

    Args:
        io_backend (str): io backend used to fetch the video bytes.
            Default: 'disk'.
        num_threads (int): Value forwarded to ``decord.VideoReader`` as
            ``num_threads``. Default: 1.
        kwargs (dict): Args for file client.
    """

    def __init__(self, io_backend='disk', num_threads=1, **kwargs):
        self.io_backend = io_backend
        self.num_threads = num_threads
        self.kwargs = kwargs
        # The file client is created lazily on the first call.
        self.file_client = None

    def __call__(self, results):
        """Perform the Decord initialization.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.

        Returns:
            dict: ``results`` with "video_reader" and "total_frames" set.

        Raises:
            ImportError: If decord is not installed.
        """
        # decord is an optional dependency, so it is imported on demand.
        try:
            import decord
        except ImportError:
            raise ImportError(
                'Please run "pip install decord" to install Decord first.')

        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        file_obj = io.BytesIO(self.file_client.get(results['filename']))
        container = decord.VideoReader(file_obj, num_threads=self.num_threads)
        results['video_reader'] = container
        results['total_frames'] = len(container)
        return results
Пример #3
class LoadImageFromFileList(LoadImageFromFile):
    """Load image from file list.

    It accepts a list of path and read each frame from each path. A list
    of frames will be returned.

    Args:
        io_backend (str): io backend where images are stored.
            Default: 'disk'.
        key (str): Keys in results to find corresponding path. Default: 'gt'.
        flag (str): Loading flag for images. Default: 'color'.
        channel_order (str): Order of channel, candidates are 'bgr' and 'rgb'.
            Default: 'bgr'.
        save_original_img (bool): If True, maintain a copy of the image in
            `results` dict with name of `f'ori_{key}'`. Default: False.
        kwargs (dict): Args for file client.
    """

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.

        Raises:
            TypeError: If the value under ``f'{key}_path'`` is not a list.
        """

        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)
        filepaths = results[f'{self.key}_path']
        if not isinstance(filepaths, list):
            raise TypeError(
                f'filepath should be list, but got {type(filepaths)}')

        filepaths = [str(v) for v in filepaths]

        imgs = []
        shapes = []
        if self.save_original_img:
            ori_imgs = []
        for filepath in filepaths:
            img_bytes = self.file_client.get(filepath)
            img = mmcv.imfrombytes(
                img_bytes, flag=self.flag,
                channel_order=self.channel_order)  # HWC
            # Single-channel images come back as HW; add the channel axis.
            if img.ndim == 2:
                img = np.expand_dims(img, axis=2)

            # Collect the per-frame outputs that are written back below.
            imgs.append(img)
            shapes.append(img.shape)
            if self.save_original_img:
                ori_imgs.append(img.copy())

        results[self.key] = imgs
        results[f'{self.key}_path'] = filepaths
        results[f'{self.key}_ori_shape'] = shapes
        if self.save_original_img:
            results[f'ori_{self.key}'] = ori_imgs

        return results
Пример #4
class AudioDecodeInit(object):
    """Using librosa to initialize the audio reader.

    Required keys are "audio_path", added or modified keys are "length",
    "sample_rate", "audios".

    Args:
        io_backend (str): io backend where frames are stored.
            Default: 'disk'.
        sample_rate (int): Audio sampling times per second. Default: 16000.
        pad_method (str): How the dummy clip is generated when the audio
            file does not exist, one of 'random' and 'zero'.
            NOTE(review): the default 'zero' is an assumption, the original
            parameter list is not visible here - confirm.
        kwargs (dict): Args for file client.

    Raises:
        NotImplementedError: If ``pad_method`` is not 'random' or 'zero'.
    """

    def __init__(self,
                 io_backend='disk',
                 sample_rate=16000,
                 pad_method='zero',
                 **kwargs):
        self.io_backend = io_backend
        self.sample_rate = sample_rate
        if pad_method in ['random', 'zero']:
            self.pad_method = pad_method
        else:
            raise NotImplementedError
        self.kwargs = kwargs
        # The file client is created lazily on the first call.
        self.file_client = None

    def _zero_pad(self, shape):
        """Return an all-zero float32 array of the given shape."""
        return np.zeros(shape, dtype=np.float32)

    def _random_pad(self, shape):
        """Return ``shape`` uniform float32 samples scaled to [-1, 1)."""
        # librosa load raw audio file into a distribution of -1~+1
        return np.random.rand(shape).astype(np.float32) * 2 - 1

    def __call__(self, results):
        """Perform the librosa initialization.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.

        Returns:
            dict: ``results`` with "length", "sample_rate" and "audios" set.

        Raises:
            ImportError: If librosa is not installed.
        """
        # librosa is an optional dependency, so it is imported on demand.
        try:
            import librosa
        except ImportError:
            raise ImportError('Please install librosa first.')

        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)
        if osp.exists(results['audio_path']):
            file_obj = io.BytesIO(self.file_client.get(results['audio_path']))
            y, sr = librosa.load(file_obj, sr=self.sample_rate)
        else:
            # Generate a dummy 10s input with the configured pad method
            pad_func = getattr(self, f'_{self.pad_method}_pad')
            y = pad_func(int(round(10.0 * self.sample_rate)))
            sr = self.sample_rate

        results['length'] = y.shape[0]
        results['sample_rate'] = sr
        results['audios'] = y
        return results
Пример #5
def load_fileclient_dist(filename, backend, map_location):
    """Load a checkpoint through a FileClient backend, local rank 0 first.

    In a distributed setting local rank 0 fetches and loads the checkpoint
    first; the remaining ranks wait at a barrier and then load it themselves.

    Args:
        filename (str): Path of the checkpoint handed to the file client.
        backend (str): FileClient backend name. Only 'ceph' is accepted.
        map_location: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The object returned by ``torch.load``.

    Raises:
        ValueError: If ``backend`` is not an allowed backend.
    """
    # Reject unsupported backends before doing any other work.
    allowed_backends = ['ceph']
    if backend not in allowed_backends:
        raise ValueError(f'Load from Backend {backend} is not supported.')

    rank, world_size = get_dist_info()
    # Prefer the launcher-provided local rank when it is available.
    rank = int(os.environ.get('LOCAL_RANK', rank))

    def _load():
        fileclient = FileClient(backend=backend)
        buffer = io.BytesIO(fileclient.get(filename))
        return torch.load(buffer, map_location=map_location)

    checkpoint = None
    if rank == 0:
        checkpoint = _load()
    if world_size > 1:
        # Hold the other ranks back until rank 0 has finished loading.
        torch.distributed.barrier()
    if rank != 0:
        # Also covers LOCAL_RANK != 0 with world_size == 1, which previously
        # left `checkpoint` unbound.
        checkpoint = _load()
    return checkpoint
Пример #6
class RandomLoadResizeBg:
    """Randomly load a background image and resize it.

    Required key is "fg", added key is "bg".

    Args:
        bg_dir (str): Path of directory to load background images from.
        io_backend (str): io backend where images are stored.
            Default: 'disk'.
        flag (str): Loading flag for images. Default: 'color'.
        channel_order (str): Order of channel, candidates are 'bgr' and 'rgb'.
            Default: 'bgr'.
        kwargs (dict): Args for file client.
    """

    def __init__(self,
                 bg_dir,
                 io_backend='disk',
                 flag='color',
                 channel_order='bgr',
                 **kwargs):
        self.bg_dir = bg_dir
        # The directory is scanned once, at construction time.
        self.bg_list = list(mmcv.scandir(bg_dir))
        self.io_backend = io_backend
        self.flag = flag
        self.channel_order = channel_order
        self.kwargs = kwargs
        # The file client is created lazily on the first call.
        self.file_client = None

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)
        h, w = results['fg'].shape[:2]
        idx = np.random.randint(len(self.bg_list))
        filepath = Path(self.bg_dir).joinpath(self.bg_list[idx])
        img_bytes = self.file_client.get(filepath)
        # Forward the configured loading flag as well; it was stored in
        # __init__ but never reached the decoder.
        img = mmcv.imfrombytes(
            img_bytes, flag=self.flag,
            channel_order=self.channel_order)  # HWC
        # Resize the background to the foreground's spatial size.
        bg = mmcv.imresize(img, (w, h), interpolation='bicubic')
        results['bg'] = bg
        return results

    def __repr__(self):
        return self.__class__.__name__ + f"(bg_dir='{self.bg_dir}')"
Пример #7
class PyAVInit(object):
    """Using pyav to initialize the video.

    PyAV: https://github.com/mikeboers/PyAV

    Required keys are "filename",
    added or modified keys are "video_reader", and "total_frames".

    Args:
        io_backend (str): io backend where frames are stored.
            Default: 'disk'.
        kwargs (dict): Args for file client.
    """
    def __init__(self, io_backend='disk', **kwargs):
        self.io_backend = io_backend
        self.kwargs = kwargs
        # The file client is created lazily on the first call.
        self.file_client = None

    def __call__(self, results):
        """Perform the PyAV initiation.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.
        """
        # NOTE(review): the `try:` line that should open this optional-import
        # guard is missing from this copy of the source.
            import av
        except ImportError:
            raise ImportError('Please run "conda install av -c conda-forge" '
                              'or "pip install av" to install PyAV first.')

        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        file_obj = io.BytesIO(self.file_client.get(results['filename']))
        container = av.open(file_obj)

        results['video_reader'] = container
        # results['total_frames'] = container.streams.video[0].frames
        # NOTE(review): the two `try:` lines and both exception-handler bodies
        # below are missing from this copy; what happens on KeyError or
        # IndexError cannot be determined from here - restore from the
        # original file.
                results['total_frames'] = container.streams.video[0].frames
            except KeyError:
        except IndexError:

        return results
Пример #8
class OpenCVInit(object):
    """Using OpenCV to initialize the video_reader.

    Required keys are "filename", added or modified keys are "new_path",
    "video_reader" and "total_frames".

    Args:
        io_backend (str): io backend used to fetch the video.
            Default: 'disk'.
        kwargs (dict): Args for file client.
    """

    def __init__(self, io_backend='disk', **kwargs):
        self.io_backend = io_backend
        self.kwargs = kwargs
        # The file client is created lazily on the first call.
        self.file_client = None
        random_string = get_random_string()
        thread_id = get_thread_id()
        # NOTE(review): the statement below is cut off in this copy (the
        # remaining `osp.join` arguments and the closing parenthesis are
        # missing); presumably the temp folder name is built from
        # `random_string` and `thread_id` - restore from the original file.
        self.tmp_folder = osp.join(get_shm_dir(),

    def __call__(self, results):
        """Perform the OpenCV initiation.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.
        """
        # NOTE(review): an `else:` appears to be missing after the 'disk'
        # branch; as written the temp-file path would always overwrite
        # `new_path` - confirm against the original file.
        if self.io_backend == 'disk':
            new_path = results['filename']
            if self.file_client is None:
                self.file_client = FileClient(self.io_backend, **self.kwargs)

            thread_id = get_thread_id()
            # save the file of same thread at the same place
            new_path = osp.join(self.tmp_folder, f'tmp_{thread_id}.mp4')
            # NOTE(review): the body of this `with` block is missing from this
            # copy; presumably it writes the fetched video bytes to `f`.
            with open(new_path, 'wb') as f:

        container = mmcv.VideoReader(new_path)
        results['new_path'] = new_path
        results['video_reader'] = container
        results['total_frames'] = len(container)

        return results

    # NOTE(review): the body of `__del__` is missing from this copy;
    # presumably it cleans up `self.tmp_folder` - restore from the original.
    def __del__(self):
class CompositeFg:
    """Composite foreground with a random foreground.

    This class composites the current training sample with additional data
    randomly (could be from the same dataset). With probability 0.5, the sample
    will be composited with a random sample from the specified directory.
    The composition is performed as:

    .. math::
        fg_{new} = \\alpha_1 * fg_1 + (1 - \\alpha_1) * fg_2

        \\alpha_{new} = 1 - (1 - \\alpha_1) * (1 - \\alpha_2)

    where :math:`(fg_1, \\alpha_1)` is from the current sample and
    :math:`(fg_2, \\alpha_2)` is the randomly loaded sample. With the above
    composition, :math:`\\alpha_{new}` is still in `[0, 1]`.

    Required keys are "alpha" and "fg". Modified keys are "alpha" and "fg".

    Args:
        fg_dirs (str | list[str]): Path of directories to load foreground
            images from.
        alpha_dirs (str | list[str]): Path of directories to load alpha mattes
            from.
        interpolation (str): Interpolation method of `mmcv.imresize` to resize
            the randomly loaded images.
            NOTE(review): the default 'nearest' is an assumption, the
            original parameter list is not visible here - confirm.
        io_backend (str): io backend where images are stored.
            NOTE(review): the default 'disk' is an assumption - confirm.
        kwargs (dict): Args for file client.
    """

    def __init__(self,
                 fg_dirs,
                 alpha_dirs,
                 interpolation='nearest',
                 io_backend='disk',
                 **kwargs):
        self.fg_dirs = fg_dirs if isinstance(fg_dirs, list) else [fg_dirs]
        self.alpha_dirs = alpha_dirs if isinstance(alpha_dirs,
                                                   list) else [alpha_dirs]
        self.interpolation = interpolation

        self.fg_list, self.alpha_list = self._get_file_list(
            self.fg_dirs, self.alpha_dirs)
        self.io_backend = io_backend
        # The file client is created lazily on the first call.
        self.file_client = None
        self.kwargs = kwargs

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)
        fg = results['fg']
        alpha = results['alpha'].astype(np.float32) / 255.
        h, w = results['fg'].shape[:2]

        # randomly select fg
        if np.random.rand() < 0.5:
            idx = np.random.randint(len(self.fg_list))
            fg2_bytes = self.file_client.get(self.fg_list[idx])
            fg2 = mmcv.imfrombytes(fg2_bytes)
            alpha2_bytes = self.file_client.get(self.alpha_list[idx])
            alpha2 = mmcv.imfrombytes(alpha2_bytes, flag='grayscale')
            alpha2 = alpha2.astype(np.float32) / 255.

            # Bring the loaded pair to the current sample's size.
            fg2 = mmcv.imresize(fg2, (w, h), interpolation=self.interpolation)
            alpha2 = mmcv.imresize(
                alpha2, (w, h), interpolation=self.interpolation)

            # the overlap of two 50% transparency will be 75%
            alpha_tmp = 1 - (1 - alpha) * (1 - alpha2)
            # if the result alpha is all-one, then we avoid composition
            if np.any(alpha_tmp < 1):
                # composite fg with fg2
                fg = fg.astype(np.float32) * alpha[..., None] \
                     + fg2.astype(np.float32) * (1 - alpha[..., None])
                alpha = alpha_tmp

        results['fg'] = fg
        results['alpha'] = (alpha * 255).astype(np.uint8)
        return results

    @staticmethod
    def _get_file_list(fg_dirs, alpha_dirs):
        """Collect paired foreground/alpha paths from paired directories.

        Args:
            fg_dirs (list[str]): Foreground directories.
            alpha_dirs (list[str]): Alpha directories, paired by position
                with ``fg_dirs``.

        Returns:
            tuple[list[str], list[str]]: Foreground paths and alpha paths,
            index-aligned.
        """
        all_fg_list = list()
        all_alpha_list = list()
        for fg_dir, alpha_dir in zip(fg_dirs, alpha_dirs):
            fg_list = sorted(mmcv.scandir(fg_dir))
            alpha_list = sorted(mmcv.scandir(alpha_dir))
            # we assume the file names for fg and alpha are the same
            assert len(fg_list) == len(alpha_list), (
                f'{fg_dir} and {alpha_dir} should have the same number of '
                f'images ({len(fg_list)} differs from ({len(alpha_list)})')
            fg_list = [osp.join(fg_dir, fg) for fg in fg_list]
            alpha_list = [osp.join(alpha_dir, alpha) for alpha in alpha_list]

            # Accumulate this directory pair into the overall lists.
            all_fg_list.extend(fg_list)
            all_alpha_list.extend(alpha_list)

        return all_fg_list, all_alpha_list

    def __repr__(self):
        repr_str = self.__class__.__name__
        # NOTE(review): the tail of this string is cut off in the source;
        # the `interpolation` field is a reconstruction - confirm.
        repr_str += (f'(fg_dirs={self.fg_dirs}, alpha_dirs={self.alpha_dirs}, '
                     f"interpolation='{self.interpolation}')")
        return repr_str
Пример #10
class LoadPairedImageFromFile(LoadImageFromFile):
    """Load a pair of images from file.

    Each sample contains a pair of images, which are concatenated in the w
    dimension (a|b). This is a special loading class for generation paired
    dataset. It loads a pair of images as the common loader does and crops
    it into two images with the same shape in different domains.

    Required key is "pair_path". Added or modified keys are "pair",
    "pair_ori_shape", "ori_pair", "img_a", "img_b", "img_a_path",
    "img_b_path", "img_a_ori_shape", "img_b_ori_shape", "ori_img_a" and
    "ori_img_b".

    Args:
        io_backend (str): io backend where images are stored.
            Default: 'disk'.
        key (str): Keys in results to find corresponding path. Default: 'gt'.
        flag (str): Loading flag for images. Default: 'color'.
        channel_order (str): Order of channel, candidates are 'bgr' and 'rgb'.
            Default: 'bgr'.
        save_original_img (bool): If True, maintain a copy of the image in
            `results` dict with name of `f'ori_{key}'`. Default: False.
        kwargs (dict): Args for file client.
    """

    def __call__(self, results):
        """Call function.

        Loads the concatenated pair, records it under ``key`` and then
        splits it down the middle of the width axis into "img_a" (left)
        and "img_b" (right).

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.

        Raises:
            ValueError: If the width of the loaded image is odd.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        path_key = f'{self.key}_path'
        pair_path = str(results[path_key])
        # NOTE(review): `channel_order` is not forwarded here, so the image
        # is decoded with the decoder's default order.
        pair = mmcv.imfrombytes(
            self.file_client.get(pair_path), flag=self.flag)  # HWC, BGR
        if pair.ndim == 2:
            pair = np.expand_dims(pair, axis=2)

        keep_copy = self.save_original_img
        results[self.key] = pair
        results[path_key] = pair_path
        results[f'{self.key}_ori_shape'] = pair.shape
        if keep_copy:
            results[f'ori_{self.key}'] = pair.copy()

        # Split the pair into its left (a) and right (b) halves.
        full_width = pair.shape[1]
        if full_width % 2:
            raise ValueError('The width of image pair must be even number, '
                             f'but got {full_width}.')
        half = full_width // 2
        halves = {'img_a': pair[:, :half, :], 'img_b': pair[:, half:, :]}

        # Keys are written group by group to keep the original key order.
        for name, half_img in halves.items():
            results[name] = half_img
        for name in halves:
            results[f'{name}_path'] = pair_path
        for name, half_img in halves.items():
            results[f'{name}_ori_shape'] = half_img.shape
        if keep_copy:
            for name, half_img in halves.items():
                results[f'ori_{name}'] = half_img.copy()

        return results
Пример #11
class LoadMask(object):
    """Load Mask for multiple types.

    For different types of mask, users need to provide the corresponding
    config dict.

    Example config for bbox:

    .. code-block:: python

        config = dict(img_shape=(256, 256), max_bbox_shape=128)

    Example config for irregular:

    .. code-block:: python

        config = dict(
            img_shape=(256, 256),
            num_vertexes=(4, 12),
            length_range=(10, 100),
            brush_width=(10, 40),
            area_ratio_range=(0.15, 0.5))

    Example config for ff:

    .. code-block:: python

        config = dict(
            img_shape=(256, 256),
            num_vertexes=(4, 12),
            brush_width=(12, 40))

    Example config for set:

    .. code-block:: python

        config = dict(
            mask_list_file='xxx/xxx/ooxx.txt',
            prefix='/xxx/xxx/ooxx/',
            io_backend='disk',
            flag='unchanged',
            file_client_kwargs=dict()
        )

        The mask_list_file contains the list of mask file name like this:
            test1.jpeg
            test2.jpeg
            ...

        The prefix gives the data path.

    Args:
        mask_mode (str): Mask mode in ['bbox', 'irregular', 'ff', 'set',
            'file'].
            * bbox: square bounding box masks.
            * irregular: irregular holes.
            * ff: free-form holes from DeepFillv2.
            * set: randomly get a mask from a mask set.
            * file: get mask from 'mask_path' in results.
        mask_config (dict): Params for creating masks. Each type of mask needs
            different configs.
    """

    def __init__(self, mask_mode='bbox', mask_config=None):
        self.mask_mode = mask_mode
        self.mask_config = dict() if mask_config is None else mask_config
        assert isinstance(self.mask_config, dict)

        # set init info if needed in some modes
        self._init_info()

    def _init_info(self):
        """Prepare the file-loading state needed by 'set' and 'file' modes."""
        if self.mask_mode == 'set':
            import os.path as osp

            # get mask list information
            self.mask_list = []
            mask_list_file = self.mask_config['mask_list_file']
            # NOTE(review): reading 'prefix' from the config and joining it
            # with the mask name is reconstructed from the class docstring
            # ("The prefix gives the data path") - confirm.
            prefix = self.mask_config['prefix']
            with open(mask_list_file, 'r') as f:
                for line in f:
                    line_split = line.strip().split(' ')
                    mask_name = line_split[0]
                    if not mask_name:
                        # Blank lines carry no mask name.
                        continue
                    self.mask_list.append(osp.join(prefix, mask_name))
            self.mask_set_size = len(self.mask_list)
            self.io_backend = self.mask_config['io_backend']
            self.flag = self.mask_config['flag']
            self.file_client_kwargs = self.mask_config['file_client_kwargs']
            self.file_client = None
        elif self.mask_mode == 'file':
            self.io_backend = 'disk'
            self.flag = 'unchanged'
            self.file_client_kwargs = dict()
            self.file_client = None

    def _get_random_mask_from_set(self):
        """Load a uniformly chosen mask from the configured mask list."""
        # The upper bound of randint is exclusive, so the index is in range.
        mask_idx = np.random.randint(0, self.mask_set_size)
        return self._get_mask_from_file(self.mask_list[mask_idx])

    def _get_mask_from_file(self, path):
        """Load the mask at ``path`` as a single-channel 0/1 array."""
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend,
                                          **self.file_client_kwargs)
        mask_bytes = self.file_client.get(path)
        mask = mmcv.imfrombytes(mask_bytes, flag=self.flag)  # HWC, BGR
        if mask.ndim == 2:
            mask = np.expand_dims(mask, axis=2)
        else:
            # Keep only the first channel of a multi-channel mask.
            mask = mask[:, :, 0:1]

        # Every non-zero pixel is treated as masked.
        mask[mask > 0] = 1.
        return mask

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.

        Raises:
            NotImplementedError: If ``mask_mode`` is not supported.
        """

        if self.mask_mode == 'bbox':
            mask_bbox = random_bbox(**self.mask_config)
            mask = bbox2mask(self.mask_config['img_shape'], mask_bbox)
            results['mask_bbox'] = mask_bbox
        elif self.mask_mode == 'irregular':
            mask = get_irregular_mask(**self.mask_config)
        elif self.mask_mode == 'set':
            mask = self._get_random_mask_from_set()
        elif self.mask_mode == 'ff':
            mask = brush_stroke_mask(**self.mask_config)
        elif self.mask_mode == 'file':
            mask = self._get_mask_from_file(results['mask_path'])
        else:
            raise NotImplementedError(
                f'Mask mode {self.mask_mode} has not been implemented.')
        results['mask'] = mask
        return results

    def __repr__(self):
        return self.__class__.__name__ + f"(mask_mode='{self.mask_mode}')"
Пример #12
class LoadImageFromFile:
    """Load image from file.

    Args:
        io_backend (str): io backend where images are stored.
            Default: 'disk'.
        key (str): Keys in results to find corresponding path. Default: 'gt'.
        flag (str): Loading flag for images. Default: 'color'.
        channel_order (str): Order of channel, candidates are 'bgr' and 'rgb'.
            Default: 'bgr'.
        convert_to (str | None): The color space of the output image. If None,
            no conversion is conducted. Default: None.
        save_original_img (bool): If True, maintain a copy of the image in
            `results` dict with name of `f'ori_{key}'`. Default: False.
        use_cache (bool): If True, keep every decoded image in an in-memory
            dict keyed by its path and reuse it on later calls.
            Default: False.
        backend (str): The image loading backend type. Options are `cv2`,
            `pillow`, and 'turbojpeg'. Default: None.
        kwargs (dict): Args for file client.
    """

    def __init__(self,
                 io_backend='disk',
                 key='gt',
                 flag='color',
                 channel_order='bgr',
                 convert_to=None,
                 save_original_img=False,
                 use_cache=False,
                 backend=None,
                 **kwargs):

        self.io_backend = io_backend
        self.key = key
        self.flag = flag
        self.save_original_img = save_original_img
        self.channel_order = channel_order
        self.convert_to = convert_to
        self.kwargs = kwargs
        # The file client and the cache are created lazily on first use.
        self.file_client = None
        self.use_cache = use_cache
        self.cache = None
        self.backend = backend

    def _read_image(self, filepath):
        """Fetch ``filepath`` through the file client and decode it."""
        img_bytes = self.file_client.get(filepath)
        return mmcv.imfrombytes(
            img_bytes,
            flag=self.flag,
            channel_order=self.channel_order,
            backend=self.backend)  # HWC

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.

        Raises:
            ValueError: If ``convert_to`` is set but the combination of
                ``channel_order`` and ``convert_to`` is not supported.
        """
        filepath = str(results[f'{self.key}_path'])
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)
        if self.use_cache:
            if self.cache is None:
                self.cache = dict()
            if filepath in self.cache:
                # The cached array itself is handed out, not a copy.
                img = self.cache[filepath]
            else:
                img = self._read_image(filepath)
                self.cache[filepath] = img
        else:
            img = self._read_image(filepath)

        if self.convert_to is not None:
            # NOTE(review): the 'rgb' branch does not check `convert_to`, so
            # any non-None value triggers the Y conversion there - confirm
            # this is intended.
            if self.channel_order == 'bgr' and self.convert_to.lower() == 'y':
                img = mmcv.bgr2ycbcr(img, y_only=True)
            elif self.channel_order == 'rgb':
                img = mmcv.rgb2ycbcr(img, y_only=True)
            else:
                raise ValueError('Currently support only "bgr2ycbcr" or '
                                 '"rgb2ycbcr".')
            # The Y-only result is 2-D; restore the channel axis.
            if img.ndim == 2:
                img = np.expand_dims(img, axis=2)

        results[self.key] = img
        results[f'{self.key}_path'] = filepath
        results[f'{self.key}_ori_shape'] = img.shape
        if self.save_original_img:
            results[f'ori_{self.key}'] = img.copy()

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (
            f'(io_backend={self.io_backend}, key={self.key}, '
            f'flag={self.flag}, save_original_img={self.save_original_img}, '
            f'channel_order={self.channel_order}, use_cache={self.use_cache})')
        return repr_str
Пример #13
class FrameSelector(object):
    """Select raw frames with given indices.

    Required keys are "frame_dir", "filename_tmpl", "modality" and
    "frame_inds" ("offset" is optional), added or modified keys are "imgs",
    "img_shape" and "original_shape".

    Args:
        io_backend (str): IO backend where frames are stored. Default: 'disk'.
        decoding_backend (str): Backend used for image decoding.
            Default: 'cv2'.
        kwargs (dict, optional): Arguments for FileClient.
    """

    def __init__(self, io_backend='disk', decoding_backend='cv2', **kwargs):
        self.io_backend = io_backend
        # NOTE(review): this value is stored but not used anywhere in the
        # code visible here - confirm where the decoding backend is applied.
        self.decoding_backend = decoding_backend
        self.kwargs = kwargs
        # The file client is created lazily on the first call.
        self.file_client = None

    def __call__(self, results):
        """Perform the FrameSelector selecting given indices.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.

        Returns:
            dict: ``results`` with "imgs", "original_shape" and "img_shape"
            set.

        Raises:
            NotImplementedError: If the modality is neither 'RGB' nor 'Flow'.
        """

        directory = results['frame_dir']
        filename_tmpl = results['filename_tmpl']
        modality = results['modality']

        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        imgs = list()

        if results['frame_inds'].ndim != 1:
            results['frame_inds'] = np.squeeze(results['frame_inds'])

        offset = results.get('offset', 0)

        for frame_idx in results['frame_inds']:
            frame_idx += offset
            if modality == 'RGB':
                filepath = osp.join(directory, filename_tmpl.format(frame_idx))
                img_bytes = self.file_client.get(filepath)
                # Get frame with channel order RGB directly.
                cur_frame = mmcv.imfrombytes(img_bytes, channel_order='rgb')
                imgs.append(cur_frame)
            elif modality == 'Flow':
                # Flow frames are stored as separate x and y images.
                x_filepath = osp.join(directory,
                                      filename_tmpl.format('x', frame_idx))
                y_filepath = osp.join(directory,
                                      filename_tmpl.format('y', frame_idx))
                x_img_bytes = self.file_client.get(x_filepath)
                x_frame = mmcv.imfrombytes(x_img_bytes, flag='grayscale')
                y_img_bytes = self.file_client.get(y_filepath)
                y_frame = mmcv.imfrombytes(y_img_bytes, flag='grayscale')
                imgs.extend([x_frame, y_frame])
            else:
                raise NotImplementedError

        results['imgs'] = imgs
        results['original_shape'] = imgs[0].shape[:2]
        results['img_shape'] = imgs[0].shape[:2]

        return results
Пример #14
class LoadKineticsPose:
    """Load Kinetics Pose given filename (The format should be pickle)

    Required keys are "filename", "total_frames", "img_shape", "frame_inds",
    "anno_inds" (for mmpose source, optional), added or modified keys are
    "keypoint", "keypoint_score".

    Args:
        io_backend (str): IO backend where frames are stored. Default: 'disk'.
        squeeze (bool): Whether to remove frames with no human pose.
            Default: True.
        max_person (int): The max number of persons in a frame. Default: 10.
        keypoint_weight (dict): The weight of keypoints. We set the confidence
            score of a person as the weighted sum of confidence scores of each
            joint. Persons with low confidence scores are dropped (if exceed
            max_person). Default: dict(face=1, torso=2, limb=3).
        source (str): The sources of the keypoints used. Choices are 'mmpose'
            and 'openpose'. Default: 'mmpose'.
        kwargs (dict, optional): Arguments for FileClient.

    Raises:
        NotImplementedError: If ``source`` is not 'mmpose' or 'openpose'.
    """

    def __init__(self,
                 io_backend='disk',
                 squeeze=True,
                 max_person=10,
                 keypoint_weight=dict(face=1, torso=2, limb=3),
                 source='mmpose',
                 **kwargs):

        self.io_backend = io_backend
        self.squeeze = squeeze
        self.max_person = max_person
        # Copied so that neither the shared default nor the caller's dict is
        # aliased by this instance.
        self.keypoint_weight = cp.deepcopy(keypoint_weight)
        self.source = source

        # Keypoint indices of each body-part group, per keypoint layout.
        if source == 'openpose':
            self.kpsubset = dict(
                face=[0, 14, 15, 16, 17],
                torso=[1, 2, 8, 5, 11],
                limb=[3, 4, 6, 7, 9, 10, 12, 13])
        elif source == 'mmpose':
            self.kpsubset = dict(
                face=[0, 1, 2, 3, 4],
                torso=[5, 6, 11, 12],
                limb=[7, 8, 9, 10, 13, 14, 15, 16])
        else:
            raise NotImplementedError('Unknown source of Kinetics Pose')

        self.kwargs = kwargs
        # The file client is created lazily on the first call.
        self.file_client = None

    def __call__(self, results):
        """Load the pickled keypoints and arrange them per person and frame.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline. "filename", "frame_inds",
                "anno_inds" and "box_score" are removed from it.

        Returns:
            dict: ``results`` with "total_frames", "num_person", "keypoint"
            and "keypoint_score" set.
        """

        assert 'filename' in results
        filename = results.pop('filename')

        # only applicable to source == 'mmpose'
        anno_inds = None
        if 'anno_inds' in results:
            assert self.source == 'mmpose'
            anno_inds = results.pop('anno_inds')
        results.pop('box_score', None)

        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        raw = self.file_client.get(filename)

        # only the kp array is in the pickle file, each kp include x, y, score.
        # NOTE(security): pickle.loads can execute arbitrary code; only use
        # this on annotation files from a trusted source.
        kps = pickle.loads(raw)

        total_frames = results['total_frames']

        frame_inds = results.pop('frame_inds')

        if anno_inds is not None:
            kps = kps[anno_inds]
            frame_inds = frame_inds[anno_inds]

        frame_inds = list(frame_inds)

        def mapinds(inds):
            # Renumber the distinct frame indices to 0..k-1, keeping order.
            uni = np.unique(inds)
            mapp = {x: i for i, x in enumerate(uni)}
            inds = [mapp[x] for x in inds]
            return np.array(inds, dtype=np.int16)

        if self.squeeze:
            frame_inds = mapinds(frame_inds)
            total_frames = np.max(frame_inds) + 1

        # write it back
        results['total_frames'] = total_frames

        h, w = results['img_shape']
        if self.source == 'openpose':
            # Scale x by the image width and y by the image height.
            kps[:, :, 0] *= w
            kps[:, :, 1] *= h

        num_kp = kps.shape[1]
        # Largest number of entries sharing one frame index. Computed with
        # numpy because indexing the result of scipy's `mode` with
        # `[-1][0]` fails on SciPy releases that return scalar counts.
        num_person = int(np.unique(frame_inds, return_counts=True)[1].max())

        # NOTE(review): the dtype arguments are cut off in the source;
        # float16 is an assumption - confirm against the original file.
        new_kp = np.zeros([num_person, total_frames, num_kp, 2],
                          dtype=np.float16)
        new_kpscore = np.zeros([num_person, total_frames, num_kp],
                               dtype=np.float16)
        # 32768 is enough
        num_person_frame = np.zeros([total_frames], dtype=np.int16)

        # Each entry takes the next free person slot of its frame.
        for frame_ind, kp in zip(frame_inds, kps):
            person_ind = num_person_frame[frame_ind]
            new_kp[person_ind, frame_ind] = kp[:, :2]
            new_kpscore[person_ind, frame_ind] = kp[:, 2]
            num_person_frame[frame_ind] += 1

        kpgrp = self.kpsubset
        weight = self.keypoint_weight
        results['num_person'] = num_person

        if num_person > self.max_person:
            # Sort the persons of every frame by weighted confidence so the
            # slice below keeps the highest-scoring ones.
            for i in range(total_frames):
                np_frame = num_person_frame[i]
                val = new_kpscore[:np_frame, i]

                val = (
                    np.sum(val[:, kpgrp['face']], 1) * weight['face'] +
                    np.sum(val[:, kpgrp['torso']], 1) * weight['torso'] +
                    np.sum(val[:, kpgrp['limb']], 1) * weight['limb'])
                inds = sorted(range(np_frame), key=lambda x: -val[x])
                new_kpscore[:np_frame, i] = new_kpscore[inds, i]
                new_kp[:np_frame, i] = new_kp[inds, i]
            results['num_person'] = self.max_person

        results['keypoint'] = new_kp[:self.max_person]
        results['keypoint_score'] = new_kpscore[:self.max_person]
        return results

    def __repr__(self):
        # NOTE(review): the tail of this string is cut off in the source;
        # the `kwargs` field is a reconstruction - confirm.
        repr_str = (f'{self.__class__.__name__}('
                    f'io_backend={self.io_backend}, '
                    f'squeeze={self.squeeze}, '
                    f'max_person={self.max_person}, '
                    f'keypoint_weight={self.keypoint_weight}, '
                    f'source={self.source}, '
                    f'kwargs={self.kwargs})')
        return repr_str