Example #1
def extract_inception_features(dataloader,
                               inception,
                               num_samples,
                               inception_style='pytorch'):
    """Extract inception features for FID metric.

    Args:
        dataloader (:obj:`DataLoader`): Dataloader for images.
        inception (nn.Module): Inception network.
        num_samples (int): The number of samples to be extracted.
        inception_style (str): The style of Inception network, "pytorch" or
            "stylegan". Defaults to "pytorch".

    Returns:
        torch.Tensor: Inception features.
    """
    batch_size = dataloader.batch_size
    num_iters = num_samples // batch_size
    if num_iters * batch_size < num_samples:
        num_iters += 1
    # define mmcv progress bar
    pbar = mmcv.ProgressBar(num_iters)

    feature_list = []
    curr_iter = 1
    for data in dataloader:
        img = data['real_img']
        pbar.update()

        # the inception network may not be wrapped with a module wrapper;
        # in that case, move the image to the module's device manually.
        if not is_module_wrapper(inception):
            img = img.to(get_module_device(inception))

        if inception_style == 'stylegan':
            img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            feature = inception(img, return_features=True)
        else:
            feature = inception(img)[0].view(img.shape[0], -1)
        feature_list.append(feature.to('cpu'))

        if curr_iter >= num_iters:
            break
        curr_iter += 1

    # Note: the number of extracted features may exceed num_samples.
    features = torch.cat(feature_list, 0)

    assert features.shape[0] >= num_samples
    features = features[:num_samples]

    # move to a new line after the progress bar output
    sys.stdout.write('\n')
    return features
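
A minimal usage sketch for the function above, assuming a DataLoader whose batches carry the key 'real_img' and a loaded Inception module (all names below are placeholders, not part of this function's API):

loader = DataLoader(dataset, batch_size=32)      # hypothetical dataset
inception = load_inception_net().eval()          # hypothetical loader helper
feats = extract_inception_features(loader, inception, num_samples=5000)
print(feats.shape)  # e.g. (5000, 2048) for the usual pool3 features
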
Example #2
    def _from_numpy(self, data):
        if isinstance(data, list):
            return [self._from_numpy(x) for x in data]

        if isinstance(data, np.ndarray):
            data = torch.from_numpy(data)
            device = get_module_device(self.generator)
            data = data.to(device)
            return data

        return data
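
A hedged sketch of the recursive conversion above (`model` stands for any module exposing `_from_numpy` and a `generator` attribute; illustrative only):

data = [np.zeros((1, 3)), [np.ones(2), np.ones(2)]]
out = model._from_numpy(data)
# -> [Tensor(1, 3), [Tensor(2), Tensor(2)]], all on the generator's device
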
    def make_injected_noise(self):
        """make noises that will be injected into feature maps.

        Returns:
            list[Tensor]: List of layer-wise noise tensor.
        """
        device = get_module_device(self)

        noises = [torch.randn(1, 1, 2**2, 2**2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises
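
A quick shape check for the list above (assuming a hypothetical generator `gen` with `log_size == 4`, i.e. a 16x16 output):

noises = gen.make_injected_noise()
# one 4x4 tensor, then two tensors per resolution from 8x8 upward
assert [n.shape[-1] for n in noises] == [4, 8, 8, 16, 16]
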
Example #4
    def sample_from_noise(self,
                          noise,
                          num_batches=0,
                          curr_scale=None,
                          sample_model='ema/orig',
                          **kwargs):
        """Sample images from noises by using the generator.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler.
            num_batches (int, optional):  The number of batch size.
                Defaults to 0.

        Returns:
            torch.Tensor | dict: The output may be the direct synthesized \
                images in ``torch.Tensor``. Otherwise, a dict with queried \
                data, including generated images, will be returned.
        """
        # use `self.curr_scale` if curr_scale is None
        if curr_scale is None:
            curr_scale = self.curr_stage

        if sample_model == 'ema':
            assert self.use_ema
            _model = self.generator_ema
        elif sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator_ema
        else:
            _model = self.generator

        if not self.fixed_noises[0].is_cuda and torch.cuda.is_available():
            self.fixed_noises = [
                x.to(get_module_device(self)) for x in self.fixed_noises
            ]

        outputs = _model(None,
                         fixed_noises=self.fixed_noises,
                         noise_weights=self.noise_weights,
                         rand_mode='rand',
                         num_batches=num_batches,
                         curr_scale=curr_scale,
                         **kwargs)

        return outputs
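
A hedged usage sketch (`model` denotes a trained multi-scale module exposing this method; the name and setup are illustrative):

imgs = model.sample_from_noise(None, num_batches=4, sample_model='ema/orig')
# 'ema/orig' falls back to the plain generator when no EMA copy is kept
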
Example #5
    def make_injected_noise(self, chosen_scale=0):
        device = get_module_device(self)

        base_scale = 2**2 + chosen_scale

        noises = [torch.randn(1, 1, base_scale, base_scale, device=device)]

        for i in range(3, self.log_size + 1):
            for n in range(2):
                _pad = 0
                if self.no_pad and not self.up_after_conv and n == 0:
                    _pad = 2
                noises.append(
                    torch.randn(1,
                                1,
                                base_scale * 2**(i - 2) + _pad,
                                base_scale * 2**(i - 2) + _pad,
                                device=device))

        return noises
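
A worked shape example (assuming `no_pad` is False so `_pad` stays 0, and a hypothetical generator `gen` with `log_size == 4`):

noises = gen.make_injected_noise(chosen_scale=2)
# base_scale = 2**2 + 2 = 6; each later level doubles the previous size
assert [n.shape[-1] for n in noises] == [6, 12, 12, 24, 24]
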
Example #6
    def __init__(self,
                 generator,
                 num_images,
                 batch_size,
                 space='W',
                 sampling='end',
                 epsilon=1e-4,
                 latent_dim=512):
        assert space in ['Z', 'W']
        assert sampling in ['full', 'end']
        n_batch = num_images // batch_size

        resid = num_images - (n_batch * batch_size)
        self.batch_sizes = [batch_size] * n_batch + ([resid]
                                                     if resid > 0 else [])
        self.device = get_module_device(generator)
        self.generator = generator
        self.latent_dim = latent_dim
        self.space = space
        self.sampling = sampling
        self.epsilon = epsilon
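
A quick check of the batch split above:

num_images, batch_size = 100, 32
n_batch = num_images // batch_size           # 3 full batches
resid = num_images - n_batch * batch_size    # remainder of 4
assert [batch_size] * n_batch + ([resid] if resid > 0 else []) \
    == [32, 32, 32, 4]
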
Example #7
def sample_from_path(generator,
                     latent_a,
                     latent_b,
                     label_a,
                     label_b,
                     intervals,
                     embedding_name=None,
                     interp_mode='lerp',
                     **kwargs):
    interp_alphas = torch.linspace(0, 1, intervals)
    interp_samples = []

    device = get_module_device(generator)
    if embedding_name is None:
        generator_name = generator.__class__.__name__
        assert generator_name in _default_embedding_name
        embedding_name = _default_embedding_name[generator_name]
    embedding_fn = getattr(generator, embedding_name, nn.Identity())
    embedding_a = embedding_fn(label_a.to(device))
    embedding_b = embedding_fn(label_b.to(device))

    for alpha in interp_alphas:
        # calculate latent interpolation
        if interp_mode == 'lerp':
            latent_interp = torch.lerp(latent_a, latent_b, alpha)
        else:
            assert latent_a.ndim == latent_b.ndim == 2
            latent_interp = slerp(latent_a, latent_b, alpha)

        # calculate embedding interpolation
        embedding_interp = embedding_a + (
            embedding_b - embedding_a) * alpha.to(embedding_a.dtype)
        if isinstance(generator, (BigGANDeepGenerator, BigGANGenerator)):
            kwargs.update(dict(use_outside_embedding=True))
        sample = batch_inference(generator, latent_interp, embedding_interp,
                                 **kwargs)
        interp_samples.append(sample)

    return interp_samples
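
A hedged usage sketch of the interpolation helper above (a conditional generator plus placeholder latents and labels; the 128-dim latent size is an assumption):

z_a, z_b = torch.randn(1, 128), torch.randn(1, 128)
label_a, label_b = torch.tensor([3]), torch.tensor([7])
frames = sample_from_path(generator, z_a, z_b, label_a, label_b,
                          intervals=10, interp_mode='slerp')
# `frames` is a list of 10 image tensors, gathered on CPU by batch_inference
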
    def forward(self,
                styles,
                num_batches=-1,
                return_noise=False,
                return_latents=False,
                inject_index=None,
                truncation=1,
                truncation_latent=None,
                input_is_latent=False,
                injected_noise=None,
                randomize_noise=True,
                transition_weight=1.,
                curr_scale=-1):
        """Forward function.

        This function has been integrated with the truncation trick. Please
        refer to the usage of `truncation` and `truncation_latent`.

        Args:
            styles (torch.Tensor | list[torch.Tensor] | callable | None): In
                StyleGAN1, you can provide a noise tensor or a latent tensor.
                Given a list containing more than one noise or latent tensor,
                the style-mixing trick will be used in training. You can also
                directly give a batch of noise through a ``torch.Tensor`` or
                offer a callable function to sample a batch of noise data.
                Otherwise, ``None`` indicates to use the default noise
                sampler.
            num_batches (int, optional): The number of samples to generate
                (i.e. the batch size). Defaults to -1.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            return_latents (bool, optional): If True, ``latent`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            inject_index (int | None, optional): The index number for mixing
                style codes. Defaults to None.
            truncation (float, optional): Truncation factor. If a value less
                than 1. is given, the truncation trick will be adopted.
                Defaults to 1.
            truncation_latent (torch.Tensor, optional): Mean truncation latent.
                Defaults to None.
            input_is_latent (bool, optional): If `True`, the input tensor is
                the latent tensor. Defaults to False.
            injected_noise (torch.Tensor | None, optional): Given a tensor, the
                random noise will be fixed as this input injected noise.
                Defaults to None.
            randomize_noise (bool, optional): If `False`, images are sampled
                with the buffered noise tensor injected to the style conv
                block. Defaults to True.
            transition_weight (float, optional): The weight used in resolution
                transition. Defaults to 1..
            curr_scale (int, optional): The resolution scale of generated image
                tensor. -1 means the max resolution scale of the StyleGAN1.
                Defaults to -1.

        Returns:
            torch.Tensor | dict: Generated image tensor or dictionary \
                containing more data.
        """
        # receive noise and conduct sanity check.
        if isinstance(styles, torch.Tensor):
            assert styles.shape[1] == self.style_channels
            styles = [styles]
        elif mmcv.is_seq_of(styles, torch.Tensor):
            for t in styles:
                assert t.shape[-1] == self.style_channels
        # receive a noise generator and sample noise.
        elif callable(styles):
            device = get_module_device(self)
            noise_generator = styles
            assert num_batches > 0
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    noise_generator((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [noise_generator((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]
        # otherwise, we will adopt default noise sampler.
        else:
            device = get_module_device(self)
            assert num_batches > 0 and not input_is_latent
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    torch.randn((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [torch.randn((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]

        if not input_is_latent:
            noise_batch = styles
            styles = [self.style_mapping(s) for s in styles]
        else:
            noise_batch = None

        if injected_noise is None:
            if randomize_noise:
                injected_noise = [None] * self.num_injected_noises
            else:
                injected_noise = [
                    getattr(self, f'injected_noise_{i}')
                    for i in range(self.num_injected_noises)
                ]
        # use truncation trick
        if truncation < 1:
            style_t = []
            # calculate truncation latent on the fly
            if truncation_latent is None and not hasattr(
                    self, 'truncation_latent'):
                self.truncation_latent = self.get_mean_latent()
                truncation_latent = self.truncation_latent
            elif truncation_latent is None and hasattr(self,
                                                       'truncation_latent'):
                truncation_latent = self.truncation_latent

            for style in styles:
                style_t.append(truncation_latent + truncation *
                               (style - truncation_latent))

            styles = style_t
        # no style mixing
        if len(styles) < 2:
            inject_index = self.num_latents

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

            else:
                latent = styles[0]
        # style mixing
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.num_latents - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(
                1, self.num_latents - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        curr_log_size = self.log_size if curr_scale < 0 else int(
            math.log2(curr_scale))
        step = curr_log_size - 2

        _index = 0
        out = latent
        # 4x4 ---> higher resolutions
        for i, (conv, to_rgb) in enumerate(zip(self.convs, self.to_rgbs)):
            if i > 0 and step > 0:
                out_prev = out
            out = conv(out,
                       latent[:, _index],
                       latent[:, _index + 1],
                       noise1=injected_noise[2 * i],
                       noise2=injected_noise[2 * i + 1])
            if i == step:
                out = to_rgb(out)

                if i > 0 and 0 <= transition_weight < 1:
                    skip_rgb = self.to_rgbs[i - 1](out_prev)
                    skip_rgb = F.interpolate(skip_rgb,
                                             scale_factor=2,
                                             mode='nearest')
                    out = (1 - transition_weight
                           ) * skip_rgb + transition_weight * out
                break

            _index += 2

        img = out

        if return_latents or return_noise:
            output_dict = dict(fake_img=img,
                               latent=latent,
                               inject_index=inject_index,
                               noise_batch=noise_batch)
            return output_dict

        return img
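
A hedged usage sketch for the forward function above (`g` denotes a trained StyleGAN1 generator; illustrative only):

imgs = g(None, num_batches=4, curr_scale=64, transition_weight=0.5)
out = g(torch.randn(4, g.style_channels), truncation=0.7,
        return_latents=True)
fake, w = out['fake_img'], out['latent']
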
    def forward(self,
                styles,
                num_batches=-1,
                return_noise=False,
                return_latents=False,
                inject_index=None,
                truncation=1,
                truncation_latent=None,
                input_is_latent=False,
                injected_noise=None,
                randomize_noise=True):
        """Forward function.

        This function has been integrated with the truncation trick. Please
        refer to the usage of `truncation` and `truncation_latent`.

        Args:
            styles (torch.Tensor | list[torch.Tensor] | callable | None): In
                StyleGAN2, you can provide a noise tensor or a latent tensor.
                Given a list containing more than one noise or latent tensor,
                the style-mixing trick will be used in training. You can also
                directly give a batch of noise through a ``torch.Tensor`` or
                offer a callable function to sample a batch of noise data.
                Otherwise, ``None`` indicates to use the default noise
                sampler.
            num_batches (int, optional): The number of samples to generate
                (i.e. the batch size). Defaults to -1.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            return_latents (bool, optional): If True, ``latent`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            inject_index (int | None, optional): The index number for mixing
                style codes. Defaults to None.
            truncation (float, optional): Truncation factor. If a value less
                than 1. is given, the truncation trick will be adopted.
                Defaults to 1.
            truncation_latent (torch.Tensor, optional): Mean truncation latent.
                Defaults to None.
            input_is_latent (bool, optional): If `True`, the input tensor is
                the latent tensor. Defaults to False.
            injected_noise (torch.Tensor | None, optional): Given a tensor, the
                random noise will be fixed as this input injected noise.
                Defaults to None.
            randomize_noise (bool, optional): If `False`, images are sampled
                with the buffered noise tensor injected to the style conv
                block. Defaults to True.

        Returns:
            torch.Tensor | dict: Generated image tensor or dictionary \
                containing more data.
        """
        # receive noise and conduct sanity check.
        if isinstance(styles, torch.Tensor):
            assert styles.shape[1] == self.style_channels
            styles = [styles]
        elif mmcv.is_seq_of(styles, torch.Tensor):
            for t in styles:
                assert t.shape[-1] == self.style_channels
        # receive a noise generator and sample noise.
        elif callable(styles):
            device = get_module_device(self)
            noise_generator = styles
            assert num_batches > 0
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    noise_generator((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [noise_generator((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]
        # otherwise, we will adopt default noise sampler.
        else:
            device = get_module_device(self)
            assert num_batches > 0 and not input_is_latent
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    torch.randn((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [torch.randn((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]

        if not input_is_latent:
            noise_batch = styles
            styles = [self.style_mapping(s) for s in styles]
        else:
            noise_batch = None

        if injected_noise is None:
            if randomize_noise:
                injected_noise = [None] * self.num_injected_noises
            else:
                injected_noise = [
                    getattr(self, f'injected_noise_{i}')
                    for i in range(self.num_injected_noises)
                ]
        # use truncation trick
        if truncation < 1:
            style_t = []
            # calculate truncation latent on the fly
            if truncation_latent is None and not hasattr(
                    self, 'truncation_latent'):
                self.truncation_latent = self.get_mean_latent()
                truncation_latent = self.truncation_latent
            elif truncation_latent is None and hasattr(self,
                                                       'truncation_latent'):
                truncation_latent = self.truncation_latent

            for style in styles:
                style_t.append(truncation_latent + truncation *
                               (style - truncation_latent))

            styles = style_t
        # no style mixing
        if len(styles) < 2:
            inject_index = self.num_latents

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

            else:
                latent = styles[0]
        # style mixing
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.num_latents - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(
                1, self.num_latents - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        # 4x4 stage
        out = self.constant_input(latent)
        out = self.conv1(out, latent[:, 0], noise=injected_noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        _index = 1

        # 8x8 ---> higher resolutions
        for up_conv, conv, noise1, noise2, to_rgb in zip(
                self.convs[::2], self.convs[1::2], injected_noise[1::2],
                injected_noise[2::2], self.to_rgbs):
            out = up_conv(out, latent[:, _index], noise=noise1)
            out = conv(out, latent[:, _index + 1], noise=noise2)
            skip = to_rgb(out, latent[:, _index + 2], skip)
            _index += 2

        # make sure the output image is torch.float32 to avoid RuntimeError
        # in other modules
        img = skip.to(torch.float32)

        if return_latents or return_noise:
            output_dict = dict(fake_img=img,
                               latent=latent,
                               inject_index=inject_index,
                               noise_batch=noise_batch)
            return output_dict

        return img
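
A hedged sketch of the style-mixing branch above (`g2` denotes a StyleGAN2 generator and the 512-dim style size is an assumption):

z1, z2 = torch.randn(2, 512), torch.randn(2, 512)
mixed = g2([z1, z2], inject_index=4)
# latent codes 0..3 come from z1's style, the remaining ones from z2's
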
Example #10
def extract_inception_features(dataloader,
                               inception,
                               num_samples,
                               inception_style='pytorch'):
    """Extract inception features for FID metric.

    Args:
        dataloader (:obj:`DataLoader`): Dataloader for images.
        inception (nn.Module): Inception network.
        num_samples (int): The number of samples to be extracted.
        inception_style (str): The style of Inception network, "pytorch" or
            "stylegan". Defaults to "pytorch".

    Returns:
        torch.Tensor: Inception features.
    """
    batch_size = dataloader.batch_size
    num_iters = num_samples // batch_size
    if num_iters * batch_size < num_samples:
        num_iters += 1
    # define mmcv progress bar
    pbar = mmcv.ProgressBar(num_iters)

    feature_list = []
    curr_iter = 1
    for data in dataloader:
        # a workaround to support multiple dataset types (mainly the
        # unconditional and the conditional dataset). In our implementation,
        # the unconditional dataset returns real images under the key
        # "real_img", while the conditional dataset stores them under the
        # key "img".
        if 'real_img' in data:
            # mainly for the unconditional dataset in our MMGeneration
            img = data['real_img']
        else:
            # mainly for the conditional dataset in MMClassification
            img = data['img']
        pbar.update()

        # the inception network may not be wrapped with a module wrapper;
        # in that case, move the image to the module's device manually.
        if not is_module_wrapper(inception):
            img = img.to(get_module_device(inception))

        if inception_style == 'stylegan':
            img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            feature = inception(img, return_features=True)
        else:
            feature = inception(img)[0].view(img.shape[0], -1)
        feature_list.append(feature.to('cpu'))

        if curr_iter >= num_iters:
            break
        curr_iter += 1

    # Note: the number of extracted features may exceed num_samples.
    features = torch.cat(feature_list, 0)

    assert features.shape[0] >= num_samples
    features = features[:num_samples]

    # move to a new line after the progress bar output
    sys.stdout.write('\n')
    return features
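
The features extracted above are typically reduced to FID statistics. A hedged sketch of that reduction (the standard Frechet distance between two Gaussians fitted to the features; this helper is not part of the file above):

import numpy as np
from scipy import linalg

def fid_from_features(feats_real, feats_fake):
    # fit a Gaussian (mean, covariance) to each feature set
    mu1, sigma1 = feats_real.mean(0), np.cov(feats_real, rowvar=False)
    mu2, sigma2 = feats_fake.mean(0), np.cov(feats_fake, rowvar=False)
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop tiny imaginary parts from sqrtm
    return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)

# fid = fid_from_features(real_feats.numpy(), fake_feats.numpy())
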
    def forward(self,
                noise,
                num_batches=0,
                input_is_latent=False,
                truncation=1,
                num_truncation_layer=None,
                update_emas=False,
                force_fp32=True,
                return_noise=False,
                return_latents=False):
        """Forward Function for stylegan3.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, ``None``
                indicates to use the default noise sampler.
            num_batches (int, optional): The number of samples to generate
                (i.e. the batch size). Defaults to 0.
            input_is_latent (bool, optional): If `True`, the input tensor is
                the latent tensor. Defaults to False.
            truncation (float, optional): Truncation factor. If a value less
                than 1. is given, the truncation trick will be adopted.
                Defaults to 1.
            num_truncation_layer (int, optional): Number of layers that use
                the truncated latent. Defaults to None.
            update_emas (bool, optional): Whether to update the moving
                average of the mean latent. Defaults to False.
            force_fp32 (bool, optional): If True, force fp32 computation
                regardless of the layers' fp16 settings. Defaults to True.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            return_latents (bool, optional): If True, ``latent`` will be
                returned in a dict with ``fake_img``. Defaults to False.

        Returns:
            torch.Tensor | dict: Generated image tensor or dictionary \
                containing more data.
        """
        # if input is latent, set noise size as the style_channels
        noise_size = (self.style_channels
                      if input_is_latent else self.noise_size)

        if isinstance(noise, torch.Tensor):
            assert noise.shape[1] == noise_size
            assert noise.ndim == 2, ('The noise should be in shape of (n, c), '
                                     f'but got {noise.shape}')
            noise_batch = noise

        # receive a noise generator and sample noise.
        elif callable(noise):
            noise_generator = noise
            assert num_batches > 0
            noise_batch = noise_generator((num_batches, noise_size))

        # otherwise, we will adopt default noise sampler.
        else:
            assert num_batches > 0
            noise_batch = torch.randn((num_batches, noise_size))

        device = get_module_device(self)
        noise_batch = noise_batch.to(device)

        if input_is_latent:
            ws = noise_batch.unsqueeze(1).repeat([1, self.num_ws, 1])
        else:
            ws = self.style_mapping(noise_batch,
                                    truncation=truncation,
                                    num_truncation_layer=num_truncation_layer,
                                    update_emas=update_emas)
        out_img = self.synthesis(ws,
                                 update_emas=update_emas,
                                 force_fp32=force_fp32)

        if self.rgb2bgr:
            out_img = out_img[:, [2, 1, 0], ...]

        if return_noise or return_latents:
            output = dict(fake_img=out_img, noise_batch=noise_batch, latent=ws)
            return output

        return out_img
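
A hedged usage sketch for the StyleGAN3 forward above (`g3` denotes a generator instance; illustrative only):

out = g3(None, num_batches=2, truncation=0.5, return_latents=True)
img, ws = out['fake_img'], out['latent']  # ws: (2, num_ws, style_channels)
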
Example #12
def batch_inference(generator,
                    noise,
                    embedding=None,
                    num_batches=-1,
                    max_batch_size=16,
                    dict_key=None,
                    **kwargs):
    """Inference function to get a batch of desired data from output dictionary
    of generator.

    Args:
        generator (nn.Module): Generator of a conditional model.
        noise (Tensor | list[torch.tensor] | None): A batch of noise
            Tensor.
        embedding (Tensor, optional): Embedding tensor of label for
            conditional models. Defaults to None.
        num_batches (int, optional): The number of batchs for
            inference. Defaults to -1.
        max_batch_size (int, optional): The number of batch size for
            inference. Defaults to 16.
        dict_key (str, optional): key used to get results from output
            dictionary of generator. Defaults to None.

    Returns:
        torch.Tensor: Tensor of output image, noise batch or label
            batch.
    """
    # split noise into groups
    if noise is not None:
        if isinstance(noise, torch.Tensor):
            num_batches = noise.shape[0]
            noise_group = torch.split(noise, max_batch_size, 0)
        else:
            num_batches = noise[0].shape[0]
            noise_group = torch.split(noise[0], max_batch_size, 0)
            noise_group = [[noise_tensor] for noise_tensor in noise_group]
    else:
        noise_group = [None] * (
            num_batches // max_batch_size +
            (1 if num_batches % max_batch_size > 0 else 0))

    # split embedding into groups
    if embedding is not None:
        assert isinstance(embedding, torch.Tensor)
        num_batches = embedding.shape[0]
        embedding_group = torch.split(embedding, max_batch_size, 0)
    else:
        embedding_group = [None] * (
            num_batches // max_batch_size +
            (1 if num_batches % max_batch_size > 0 else 0))

    # split the total number of samples into per-pass batch sizes
    batchsize_group = [max_batch_size] * (num_batches // max_batch_size)
    if num_batches % max_batch_size > 0:
        batchsize_group += [num_batches % max_batch_size]

    device = get_module_device(generator)
    outputs = []
    for _noise, _embedding, _num_batches in zip(noise_group, embedding_group,
                                                batchsize_group):
        if isinstance(_noise, torch.Tensor):
            _noise = _noise.to(device)
        if isinstance(_noise, list):
            _noise = [ele.to(device) for ele in _noise]
        if _embedding is not None:
            _embedding = _embedding.to(device)
        output = generator(
            _noise, label=_embedding, num_batches=_num_batches, **kwargs)
        output = output[dict_key] if dict_key else output
        if isinstance(output, list):
            output = output[0]
        # once the sampled results are obtained, we immediately move them to
        # CPU to save CUDA memory
        outputs.append(output.to('cpu'))
    outputs = torch.cat(outputs, dim=0)
    return outputs
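
A hedged usage sketch: sample 100 images in chunks of 16 so peak GPU memory stays bounded (the generator interface and the 512-dim noise size are assumptions):

noise = torch.randn(100, 512)
imgs = batch_inference(generator, noise, max_batch_size=16,
                       dict_key='fake_img', return_noise=True)
assert imgs.shape[0] == 100 and not imgs.is_cuda  # gathered on CPU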