Example #1
    def __init__(self, channel, size=4):
        super().__init__()
        if isinstance(size, int):
            size = [size, size]
        elif mmcv.is_seq_of(size, int):
            assert len(
                size
            ) == 2, f'The length of size should be 2 but got {len(size)}'
        else:
            raise ValueError(f'Got invalid value in size, {size}')

        self.input = nn.Parameter(torch.randn(1, channel, *size))
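
A minimal standalone sketch of the size check above. The free function
normalize_size is introduced here for illustration only and is not part of
the original module:

import mmcv

def normalize_size(size):
    # Normalize ``size`` to a 2-element list, mirroring the check above.
    if isinstance(size, int):
        return [size, size]
    if mmcv.is_seq_of(size, int):
        assert len(size) == 2, \
            f'The length of size should be 2 but got {len(size)}'
        return list(size)
    raise ValueError(f'Got invalid value in size, {size}')

assert normalize_size(4) == [4, 4]
assert normalize_size((4, 8)) == [4, 8]  # any int sequence of length 2 passes
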
Example #2
    def __init__(self,
                 buffer_type: type = Buffer,
                 buffers: Optional[Dict] = None):
        self.buffer_type = buffer_type
        if buffers is None:
            self._buffers = {}
        else:
            if is_seq_of(list(buffers.values()), buffer_type):
                self._buffers = buffers.copy()
            else:
                raise ValueError('The values of buffers should be instance '
                                 f'of {buffer_type}')
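
Here is_seq_of validates every value of a dict in a single call by wrapping
buffers.values() in a list. A standalone sketch of the same pattern, with
queue.Queue standing in for Buffer (check_values is a hypothetical helper):

from queue import Queue

import mmcv

def check_values(mapping, expected_type):
    # Reject the mapping unless all of its values have the expected type.
    if not mmcv.is_seq_of(list(mapping.values()), expected_type):
        raise ValueError(f'The values should be instances of {expected_type}')
    return mapping.copy()

buffers = {'input': Queue(), 'output': Queue()}
check_values(buffers, Queue)  # passes; a non-Queue value would raise
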
Example #3
def test_is_seq_of():
    assert mmcv.is_seq_of([1.0, 2.0, 3.0], float)
    assert mmcv.is_seq_of([(1, ), (2, ), (3, )], tuple)
    assert mmcv.is_seq_of((1.0, 2.0, 3.0), float)
    assert mmcv.is_list_of([1.0, 2.0, 3.0], float)
    assert not mmcv.is_seq_of((1.0, 2.0, 3.0), float, seq_type=list)
    assert not mmcv.is_tuple_of([1.0, 2.0, 3.0], float)
    assert not mmcv.is_seq_of([1.0, 2, 3], int)
    assert not mmcv.is_seq_of((1.0, 2, 3), int)
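
The assertions pin down the semantics: is_seq_of accepts any
collections.abc.Sequence unless seq_type narrows it, and is_list_of /
is_tuple_of are the list- and tuple-restricted variants. A minimal
re-implementation consistent with these tests (a sketch, not necessarily
mmcv's actual source):

from collections.abc import Sequence

def is_seq_of(seq, expected_type, seq_type=None):
    # Default to any abc.Sequence; a concrete seq_type narrows the check.
    exp_seq_type = Sequence if seq_type is None else seq_type
    if not isinstance(seq, exp_seq_type):
        return False
    return all(isinstance(item, expected_type) for item in seq)

assert is_seq_of((1.0, 2.0), float)
assert not is_seq_of((1.0, 2.0), float, seq_type=list)
assert not is_seq_of([1.0, 2, 3], int)  # the leading float fails the check
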
Example #4
    def forward(self,
                styles,
                num_batches=-1,
                return_noise=False,
                return_latents=False,
                inject_index=None,
                truncation=1,
                truncation_latent=None,
                input_is_latent=False,
                injected_noise=None,
                randomize_noise=True,
                transition_weight=1.,
                curr_scale=-1):
        """Forward function.

        This function has been integrated with the truncation trick. Please
        refer to the usage of `truncation` and `truncation_latent`.

        Args:
            styles (torch.Tensor | list[torch.Tensor] | callable | None): In
                StyleGAN1, you can provide a noise tensor or a latent tensor.
                If a list containing more than one noise or latent tensor is
                given, the style mixing trick will be used in training. You
                can also directly give a batch of noise through a
                ``torch.Tensor`` or offer a callable function to sample a
                batch of noise data. Otherwise, ``None`` indicates to use the
                default noise sampler.
            num_batches (int, optional): The number of samples to generate,
                i.e. the batch size. Defaults to -1.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            return_latents (bool, optional): If True, ``latent`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            inject_index (int | None, optional): The index number for mixing
                style codes. Defaults to None.
            truncation (float, optional): Truncation factor. If a value less
                than 1. is given, the truncation trick will be adopted.
                Defaults to 1.
            truncation_latent (torch.Tensor, optional): Mean truncation latent.
                Defaults to None.
            input_is_latent (bool, optional): If `True`, the input tensor is
                the latent tensor. Defaults to False.
            injected_noise (torch.Tensor | None, optional): Given a tensor, the
                random noise will be fixed as this input injected noise.
                Defaults to None.
            randomize_noise (bool, optional): If `False`, images are sampled
                with the buffered noise tensors injected into the style conv
                blocks. Defaults to True.
            transition_weight (float, optional): The weight used in resolution
                transition. Defaults to 1.0.
            curr_scale (int, optional): The resolution scale of generated image
                tensor. -1 means the max resolution scale of the StyleGAN1.
                Defaults to -1.

        Returns:
            torch.Tensor | dict: Generated image tensor or dictionary \
                containing more data.
        """
        # receive noise and conduct sanity check.
        if isinstance(styles, torch.Tensor):
            assert styles.shape[1] == self.style_channels
            styles = [styles]
        elif mmcv.is_seq_of(styles, torch.Tensor):
            for t in styles:
                assert t.shape[-1] == self.style_channels
        # receive a noise generator and sample noise.
        elif callable(styles):
            device = get_module_device(self)
            noise_generator = styles
            assert num_batches > 0
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    noise_generator((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [noise_generator((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]
        # otherwise, we will adopt default noise sampler.
        else:
            device = get_module_device(self)
            assert num_batches > 0 and not input_is_latent
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    torch.randn((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [torch.randn((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]

        if not input_is_latent:
            noise_batch = styles
            styles = [self.style_mapping(s) for s in styles]
        else:
            noise_batch = None

        if injected_noise is None:
            if randomize_noise:
                injected_noise = [None] * self.num_injected_noises
            else:
                injected_noise = [
                    getattr(self, f'injected_noise_{i}')
                    for i in range(self.num_injected_noises)
                ]
        # use truncation trick
        if truncation < 1:
            style_t = []
            # calculate truncation latent on the fly
            if truncation_latent is None and not hasattr(
                    self, 'truncation_latent'):
                self.truncation_latent = self.get_mean_latent()
                truncation_latent = self.truncation_latent
            elif truncation_latent is None and hasattr(self,
                                                       'truncation_latent'):
                truncation_latent = self.truncation_latent

            for style in styles:
                style_t.append(truncation_latent + truncation *
                               (style - truncation_latent))

            styles = style_t
        # no style mixing
        if len(styles) < 2:
            inject_index = self.num_latents

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

            else:
                latent = styles[0]
        # style mixing
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.num_latents - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(
                1, self.num_latents - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        curr_log_size = self.log_size if curr_scale < 0 else int(
            math.log2(curr_scale))
        step = curr_log_size - 2

        _index = 0
        out = latent
        # 4x4 ---> higher resolutions
        for i, (conv, to_rgb) in enumerate(zip(self.convs, self.to_rgbs)):
            if i > 0 and step > 0:
                out_prev = out
            out = conv(out,
                       latent[:, _index],
                       latent[:, _index + 1],
                       noise1=injected_noise[2 * i],
                       noise2=injected_noise[2 * i + 1])
            if i == step:
                out = to_rgb(out)

                if i > 0 and 0 <= transition_weight < 1:
                    skip_rgb = self.to_rgbs[i - 1](out_prev)
                    skip_rgb = F.interpolate(skip_rgb,
                                             scale_factor=2,
                                             mode='nearest')
                    out = (1 - transition_weight
                           ) * skip_rgb + transition_weight * out
                break

            _index += 2

        img = out

        if return_latents or return_noise:
            output_dict = dict(fake_img=img,
                               latent=latent,
                               inject_index=inject_index,
                               noise_batch=noise_batch)
            return output_dict

        return img
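
The truncation branch above interpolates each style code toward a mean
latent. The arithmetic in isolation, as a runnable sketch (the 512-dim style
size is an assumption; truncation_latent comes from get_mean_latent() in the
real module):

import torch

def truncate_styles(styles, truncation, truncation_latent):
    # Move each style code toward the mean latent by a factor (1 - truncation).
    return [truncation_latent + truncation * (s - truncation_latent)
            for s in styles]

mean_w = torch.zeros(1, 512)  # stand-in for self.get_mean_latent()
w = torch.randn(4, 512)
w_t = truncate_styles([w], truncation=0.7, truncation_latent=mean_w)[0]
# Each latent keeps 70% of its offset from the mean, trading diversity for
# fidelity; truncation=1. leaves the codes unchanged.
assert torch.allclose(w_t, 0.7 * w)  # holds because mean_w is zero here
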
Example #5
    def forward(self,
                styles,
                num_batches=-1,
                return_noise=False,
                return_latents=False,
                inject_index=None,
                truncation=1,
                truncation_latent=None,
                input_is_latent=False,
                injected_noise=None,
                randomize_noise=True):
        """Forward function.

        This function has been integrated with the truncation trick. Please
        refer to the usage of `truncation` and `truncation_latent`.

        Args:
            styles (torch.Tensor | list[torch.Tensor] | callable | None): In
                StyleGAN2, you can provide a noise tensor or a latent tensor.
                If a list containing more than one noise or latent tensor is
                given, the style mixing trick will be used in training. You
                can also directly give a batch of noise through a
                ``torch.Tensor`` or offer a callable function to sample a
                batch of noise data. Otherwise, ``None`` indicates to use the
                default noise sampler.
            num_batches (int, optional): The number of samples to generate,
                i.e. the batch size. Defaults to -1.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            return_latents (bool, optional): If True, ``latent`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            inject_index (int | None, optional): The index number for mixing
                style codes. Defaults to None.
            truncation (float, optional): Truncation factor. If a value less
                than 1. is given, the truncation trick will be adopted.
                Defaults to 1.
            truncation_latent (torch.Tensor, optional): Mean truncation latent.
                Defaults to None.
            input_is_latent (bool, optional): If `True`, the input tensor is
                the latent tensor. Defaults to False.
            injected_noise (torch.Tensor | None, optional): Given a tensor, the
                random noise will be fixed as this input injected noise.
                Defaults to None.
            randomize_noise (bool, optional): If `False`, images are sampled
                with the buffered noise tensors injected into the style conv
                blocks. Defaults to True.

        Returns:
            torch.Tensor | dict: Generated image tensor or dictionary \
                containing more data.
        """
        # receive noise and conduct sanity check.
        if isinstance(styles, torch.Tensor):
            assert styles.shape[1] == self.style_channels
            styles = [styles]
        elif mmcv.is_seq_of(styles, torch.Tensor):
            for t in styles:
                assert t.shape[-1] == self.style_channels
        # receive a noise generator and sample noise.
        elif callable(styles):
            device = get_module_device(self)
            noise_generator = styles
            assert num_batches > 0
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    noise_generator((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [noise_generator((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]
        # otherwise, we will adopt default noise sampler.
        else:
            device = get_module_device(self)
            assert num_batches > 0 and not input_is_latent
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    torch.randn((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [torch.randn((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]

        if not input_is_latent:
            noise_batch = styles
            styles = [self.style_mapping(s) for s in styles]
        else:
            noise_batch = None

        if injected_noise is None:
            if randomize_noise:
                injected_noise = [None] * self.num_injected_noises
            else:
                injected_noise = [
                    getattr(self, f'injected_noise_{i}')
                    for i in range(self.num_injected_noises)
                ]
        # use truncation trick
        if truncation < 1:
            style_t = []
            # calculate truncation latent on the fly
            if truncation_latent is None and not hasattr(
                    self, 'truncation_latent'):
                self.truncation_latent = self.get_mean_latent()
                truncation_latent = self.truncation_latent
            elif truncation_latent is None and hasattr(self,
                                                       'truncation_latent'):
                truncation_latent = self.truncation_latent

            for style in styles:
                style_t.append(truncation_latent + truncation *
                               (style - truncation_latent))

            styles = style_t
        # no style mixing
        if len(styles) < 2:
            inject_index = self.num_latents

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

            else:
                latent = styles[0]
        # style mixing
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.num_latents - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(
                1, self.num_latents - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        # 4x4 stage
        out = self.constant_input(latent)
        out = self.conv1(out, latent[:, 0], noise=injected_noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        _index = 1

        # 8x8 ---> higher resolutions
        for up_conv, conv, noise1, noise2, to_rgb in zip(
                self.convs[::2], self.convs[1::2], injected_noise[1::2],
                injected_noise[2::2], self.to_rgbs):
            out = up_conv(out, latent[:, _index], noise=noise1)
            out = conv(out, latent[:, _index + 1], noise=noise2)
            skip = to_rgb(out, latent[:, _index + 2], skip)
            _index += 2

        # make sure the output image is torch.float32 to avoid a RuntimeError
        # in other modules
        img = skip.to(torch.float32)

        if return_latents or return_noise:
            output_dict = dict(fake_img=img,
                               latent=latent,
                               inject_index=inject_index,
                               noise_batch=noise_batch)
            return output_dict

        return img
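
Both forward functions assemble the per-layer latent the same way when style
mixing is active: the first inject_index layers take the first code and the
remaining layers take the second. The assembly in isolation (num_latents and
the 512-dim style size are assumptions here):

import random

import torch

num_latents = 14  # e.g. 2 * log2(out_size) - 2 layers in StyleGAN2
w1 = torch.randn(4, 512)  # first style code (batch of 4)
w2 = torch.randn(4, 512)  # second style code used for mixing
inject_index = random.randint(1, num_latents - 1)
latent = torch.cat([
    w1.unsqueeze(1).repeat(1, inject_index, 1),
    w2.unsqueeze(1).repeat(1, num_latents - inject_index, 1),
], dim=1)
assert latent.shape == (4, num_latents, 512)
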
Example #6
    def smooth(self, results):
        """Apply temporal smoothing on pose estimation sequences.

        Args:
            results (list[dict] | list[list[dict]]): The pose results of a
                single frame (non-nested list) or multiple frames (nested
                list). The result of each target is a dict, which should
                contain:

                - track_id (optional, Any): The track ID of the target
                - keypoints (np.ndarray): The keypoint coordinates in [K, C]

        Returns:
            (list[dict] | list[list[dict]]): Temporally smoothed pose results,
            which have the same data structure as the input.
        """

        # Check if input is empty
        if not results or not results[0]:
            warnings.warn('Smoother received empty result.')
            return results

        # Check whether the input is a single frame or a sequence
        if is_seq_of(results, dict):
            single_frame = True
            results = [results]
        else:
            assert is_seq_of(results, list)
            single_frame = False

        # Get temporal length of input
        T = len(results)

        # Collate the input results to pose sequences
        poses = self._collate_pose(results)

        # Smooth the pose sequence of each target
        smoothed_poses = {}
        update_history = {}
        for track_id, pose in poses.items():
            if track_id in self.history:
                # For tracked target, get its filter and pose history
                pose_history, pose_filter = self.history[track_id]
                if self.padding_size > 0:
                    # Pad the pose sequence with pose history
                    pose = np.concatenate((pose_history, pose), axis=0)
            else:
                # For new target, build a new filter
                pose_filter = self._get_filter()

            # Update the history information
            if self.padding_size > 0:
                pose_history = pose[-self.padding_size:].copy()
            else:
                pose_history = None
            update_history[track_id] = (pose_history, pose_filter)

            # Smooth the pose sequence with the filter
            smoothed_pose = pose_filter(pose)
            smoothed_poses[track_id] = smoothed_pose[-T:]

        self.history = update_history

        # Scatter the pose sequences back to the format of results
        smoothed_results = self._scatter_pose(results, smoothed_poses)

        # If the input is a single frame, remove the nested list to keep the
        # output structure consistent with the input
        if single_frame:
            smoothed_results = smoothed_results[0]
        return smoothed_results
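
Here is_seq_of(results, dict) distinguishes a single frame from a sequence of
frames. A sketch of the two accepted input layouts (the Smoother construction
itself is elided; the keypoint shape [K, C] follows the docstring):

import numpy as np

frame = [
    dict(track_id=0, keypoints=np.random.rand(17, 2)),  # one target
    dict(track_id=1, keypoints=np.random.rand(17, 2)),
]
single = frame          # non-nested list[dict]: one frame
sequence = [frame] * 5  # nested list[list[dict]]: five frames
# smoother.smooth(single)   -> list[dict]
# smoother.smooth(sequence) -> list[list[dict]]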