def extract_inception_features(dataloader,
                               inception,
                               num_samples,
                               inception_style='pytorch'):
    """Extract inception features for FID metric.

    Args:
        dataloader (:obj:`DataLoader`): Dataloader for images. Each batch is
            a dict holding the images under ``'real_img'`` (unconditional
            datasets) or, when that key is absent, under ``'img'``
            (conditional datasets).
        inception (nn.Module): Inception network.
        num_samples (int): The number of samples to be extracted.
        inception_style (str): The style of Inception network, "pytorch" or
            "stylegan". Defaults to "pytorch".

    Returns:
        torch.Tensor: Inception features, truncated to ``num_samples`` rows.
    """
    batch_size = dataloader.batch_size
    num_iters = num_samples // batch_size
    # One more (partially used) batch is needed when num_samples is not a
    # multiple of batch_size.
    if num_iters * batch_size < num_samples:
        num_iters += 1

    # define mmcv progress bar
    pbar = mmcv.ProgressBar(num_iters)

    feature_list = []
    curr_iter = 1
    for data in dataloader:
        # Support both unconditional ("real_img") and conditional ("img")
        # datasets.
        if 'real_img' in data:
            img = data['real_img']
        else:
            img = data['img']
        pbar.update()

        # the inception network is not wrapped with module wrapper.
        if not is_module_wrapper(inception):
            # put the img to the module device
            img = img.to(get_module_device(inception))

        if inception_style == 'stylegan':
            img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            feature = inception(img, return_features=True)
        else:
            feature = inception(img)[0].view(img.shape[0], -1)
        # keep features on the CPU to save device memory
        feature_list.append(feature.to('cpu'))

        if curr_iter >= num_iters:
            break
        curr_iter += 1

    # The last batch may overshoot, so the features are cut to num_samples.
    features = torch.cat(feature_list, 0)
    assert features.shape[0] >= num_samples, (
        f'Only {features.shape[0]} features were extracted, but '
        f'{num_samples} were requested; the dataloader ran out of data.')
    features = features[:num_samples]

    # to change the line after pbar
    sys.stdout.write('\n')
    return features
def _from_numpy(self, data):
    """Recursively turn numpy arrays into tensors on the generator device.

    Args:
        data: A ``np.ndarray``, a (possibly nested) ``list`` or any other
            object.

    Returns:
        For an ``np.ndarray``: a ``torch.Tensor`` moved to the device of
        ``self.generator``. For a ``list``: a new list whose elements were
        converted the same way. Anything else is returned unchanged.
    """
    if isinstance(data, np.ndarray):
        tensor = torch.from_numpy(data)
        return tensor.to(get_module_device(self.generator))
    if isinstance(data, list):
        return [self._from_numpy(item) for item in data]
    return data
def make_injected_noise(self):
    """Create the layer-wise noise maps that are injected into features.

    Returns:
        list[Tensor]: A ``(1, 1, 4, 4)`` tensor followed, for every ``i`` in
        ``3 .. self.log_size``, by two ``(1, 1, 2**i, 2**i)`` tensors. All
        are drawn with ``torch.randn`` on the module's device.
    """
    device = get_module_device(self)
    noises = [torch.randn(1, 1, 2**2, 2**2, device=device)]
    noises.extend([
        torch.randn(1, 1, 2**res, 2**res, device=device)
        for res in range(3, self.log_size + 1)
        for _ in range(2)
    ])
    return noises
def sample_from_noise(self,
                      noise,
                      num_batches=0,
                      curr_scale=None,
                      sample_model='ema/orig',
                      **kwargs):
    """Sample images from noises by using the generator.

    Args:
        noise (torch.Tensor | callable | None): Accepted for interface
            compatibility but NOT forwarded: the generator is always called
            with ``None`` as its noise argument together with
            ``self.fixed_noises``.
        num_batches (int, optional): The number of batch size.
            Defaults to 0.
        curr_scale (int | None, optional): Scale passed to the generator.
            ``None`` means using ``self.curr_stage``. Defaults to None.
        sample_model (str, optional): Which generator to sample from.
            ``'ema'`` requires ``self.use_ema`` and uses
            ``self.generator_ema``. ``'ema/orig'`` uses
            ``self.generator_ema`` when ``self.use_ema`` is set and
            ``self.generator`` otherwise. Any other value uses
            ``self.generator``. Defaults to ``'ema/orig'``.
        **kwargs: Extra keyword arguments forwarded to the generator.

    Returns:
        torch.Tensor | dict: The output may be the direct synthesized
            images in ``torch.Tensor``. Otherwise, a dict with queried
            data, including generated images, will be returned.
    """
    # use `self.curr_stage` if curr_scale is None
    if curr_scale is None:
        curr_scale = self.curr_stage

    if sample_model == 'ema':
        assert self.use_ema
        _model = self.generator_ema
    elif sample_model == 'ema/orig' and self.use_ema:
        _model = self.generator_ema
    else:
        _model = self.generator

    # Always align the fixed noises with the module's device, not only when
    # they sit on the CPU, so noises left on another GPU (or on a GPU after
    # the model moved to CPU) are moved as well. ``Tensor.to`` returns the
    # tensor itself when it is already on the target device, so this is
    # cheap in the common case.
    device = get_module_device(self)
    self.fixed_noises = [x.to(device) for x in self.fixed_noises]

    outputs = _model(
        None,
        fixed_noises=self.fixed_noises,
        noise_weights=self.noise_weights,
        rand_mode='rand',
        num_batches=num_batches,
        curr_scale=curr_scale,
        **kwargs)

    return outputs
def make_injected_noise(self, chosen_scale=0):
    """Create layer-wise noise maps starting from a shifted base size.

    Args:
        chosen_scale (int, optional): Added to the base side length of 4.
            The first map is ``(4 + chosen_scale)`` pixels square, and for
            stage ``i`` (``3 .. self.log_size``) the side is
            ``(4 + chosen_scale) * 2**(i - 2)``. Defaults to 0.

    Returns:
        list[Tensor]: One ``(1, 1, s, s)`` tensor for the base stage and two
        per later stage. The first map of a later stage gets 2 extra
        pixels per side when ``self.no_pad`` is truthy and
        ``self.up_after_conv`` is falsy. All are drawn on the module's
        device.
    """
    device = get_module_device(self)
    base = 2**2 + chosen_scale
    noise_list = [torch.randn(1, 1, base, base, device=device)]
    for stage in range(3, self.log_size + 1):
        side = base * 2**(stage - 2)
        for pos in range(2):
            extra = 2 if (self.no_pad and not self.up_after_conv
                          and pos == 0) else 0
            size = side + extra
            noise_list.append(torch.randn(1, 1, size, size, device=device))
    return noise_list
def __init__(self,
             generator,
             num_images,
             batch_size,
             space='W',
             sampling='end',
             epsilon=1e-4,
             latent_dim=512):
    """Store the sampling configuration for ``generator``.

    Args:
        generator (nn.Module): Generator to sample from; its device is
            stored as ``self.device``.
        num_images (int): Total number of images. It is split into
            ``self.batch_sizes``: as many full ``batch_size`` entries as
            fit, plus one trailing entry for any positive remainder.
        batch_size (int): Size of each full batch.
        space (str): Latent space, ``'Z'`` or ``'W'``. Defaults to ``'W'``.
        sampling (str): ``'full'`` or ``'end'``. Defaults to ``'end'``.
        epsilon (float): Stored as ``self.epsilon``; presumably the step
            size used when perturbing latents (TODO confirm against the
            sampling code). Defaults to 1e-4.
        latent_dim (int): Dimension of the latent codes. Defaults to 512.
    """
    assert space in ['Z', 'W']
    assert sampling in ['full', 'end']

    full_batches, remainder = divmod(num_images, batch_size)
    self.batch_sizes = [batch_size] * full_batches
    if remainder > 0:
        self.batch_sizes.append(remainder)

    self.device = get_module_device(generator)
    self.generator = generator
    self.latent_dim = latent_dim
    self.space = space
    self.sampling = sampling
    self.epsilon = epsilon
def sample_from_path(generator,
                     latent_a,
                     latent_b,
                     label_a,
                     label_b,
                     intervals,
                     embedding_name=None,
                     interp_mode='lerp',
                     **kwargs):
    """Generate samples along an interpolation path between two endpoints.

    Both the latent codes and the label embeddings are interpolated with
    ``intervals`` evenly spaced weights from 0 to 1 (inclusive).

    Args:
        generator (nn.Module): Conditional generator.
        latent_a (torch.Tensor): Latent code at the start of the path.
        latent_b (torch.Tensor): Latent code at the end of the path.
        label_a (torch.Tensor): Label at the start of the path.
        label_b (torch.Tensor): Label at the end of the path.
        intervals (int): Number of points on the path.
        embedding_name (str | None, optional): Attribute name of the
            generator's label-embedding module. ``None`` looks it up in
            ``_default_embedding_name`` by the generator's class name. If the
            generator has no such attribute, an identity mapping is used.
            Defaults to None.
        interp_mode (str, optional): ``'lerp'`` for linear interpolation;
            any other value uses ``slerp``, which requires 2-D latents.
            Defaults to ``'lerp'``.
        **kwargs: Forwarded to ``batch_inference``. For BigGAN generators
            ``use_outside_embedding=True`` is added.

    Returns:
        list[torch.Tensor]: One batch of samples per interpolation weight.
    """
    alphas = torch.linspace(0, 1, intervals)
    device = get_module_device(generator)

    if embedding_name is None:
        generator_name = generator.__class__.__name__
        assert generator_name in _default_embedding_name
        embedding_name = _default_embedding_name[generator_name]
    embed = getattr(generator, embedding_name, nn.Identity())
    embed_a = embed(label_a.to(device))
    embed_b = embed(label_b.to(device))

    # BigGAN generators must consume the pre-computed embedding directly.
    if isinstance(generator, (BigGANDeepGenerator, BigGANGenerator)):
        kwargs.update(dict(use_outside_embedding=True))

    samples = []
    for alpha in alphas:
        if interp_mode == 'lerp':
            latent = torch.lerp(latent_a, latent_b, alpha)
        else:
            assert latent_a.ndim == latent_b.ndim == 2
            latent = slerp(latent_a, latent_b, alpha)
        embedding = embed_a + (embed_b - embed_a) * alpha.to(embed_a.dtype)
        samples.append(
            batch_inference(generator, latent, embedding, **kwargs))
    return samples
def forward(self,
            styles,
            num_batches=-1,
            return_noise=False,
            return_latents=False,
            inject_index=None,
            truncation=1,
            truncation_latent=None,
            input_is_latent=False,
            injected_noise=None,
            randomize_noise=True,
            transition_weight=1.,
            curr_scale=-1):
    """Forward function.

    This function has been integrated with the truncation trick. Please
    refer to the usage of `truncation` and `truncation_latent`.

    Args:
        styles (torch.Tensor | list[torch.Tensor] | callable | None): In
            StyleGAN1, you can provide noise tensor or latent tensor. Given
            a list containing more than one noise or latent tensors, style
            mixing trick will be used in training. Of course, You can
            directly give a batch of noise through a ``torch.Tensor`` or
            offer a callable function to sample a batch of noise data.
            Otherwise, the ``None`` indicates to use the default noise
            sampler.
        num_batches (int, optional): The number of batch size. Only used
            (and required to be positive) when ``styles`` is a callable or
            ``None``. Defaults to -1.
        return_noise (bool, optional): If True, ``noise_batch`` will be
            returned in a dict with ``fake_img``. Defaults to False.
        return_latents (bool, optional): If True, ``latent`` will be
            returned in a dict with ``fake_img``. Defaults to False.
        inject_index (int | None, optional): The index number for mixing
            style codes. Only used when two style codes are present; then
            ``None`` draws it with ``random.randint(1, self.num_latents -
            1)``. With a single style code it is overwritten with
            ``self.num_latents``. Defaults to None.
        truncation (float, optional): Truncation factor. Give value less
            than 1., the truncation trick will be adopted. Defaults to 1.
        truncation_latent (torch.Tensor, optional): Mean truncation latent.
            If None while truncating, ``self.truncation_latent`` is used; it
            is computed once with ``self.get_mean_latent()`` and cached on
            first use. Defaults to None.
        input_is_latent (bool, optional): If `True`, the input tensor is
            the latent tensor and ``self.style_mapping`` is skipped.
            Defaults to False.
        injected_noise (list[torch.Tensor | None] | None, optional): Noise
            for every layer; conv block ``i`` reads entries ``2 * i`` and
            ``2 * i + 1``. Given a list, the random noise will be fixed as
            this input injected noise. Defaults to None.
        randomize_noise (bool, optional): Only used when ``injected_noise``
            is None. If `False`, images are sampled with the buffered noise
            tensors ``self.injected_noise_{i}``; otherwise ``None`` is
            passed for every layer. Defaults to True.
        transition_weight (float, optional): The weight used in resolution
            transition. A value in ``[0, 1)`` blends the output of the
            current stage with the 2x nearest-upsampled RGB output of the
            previous stage. Defaults to 1.
        curr_scale (int, optional): The resolution scale of generated image
            tensor. -1 means the max resolution scale of the StyleGAN1.
            Defaults to -1.

    Returns:
        torch.Tensor | dict: Generated image tensor or, when
            ``return_noise`` or ``return_latents`` is set, a dict with keys
            ``fake_img``, ``latent``, ``inject_index`` and ``noise_batch``
            (``noise_batch`` is None when ``input_is_latent`` is True).
    """
    # receive noise and conduct sanity check.
    if isinstance(styles, torch.Tensor):
        assert styles.shape[1] == self.style_channels
        styles = [styles]
    elif mmcv.is_seq_of(styles, torch.Tensor):
        for t in styles:
            assert t.shape[-1] == self.style_channels
    # receive a noise generator and sample noise.
    elif callable(styles):
        device = get_module_device(self)
        noise_generator = styles
        assert num_batches > 0
        # with probability ``self.mix_prob`` sample two codes for mixing
        if self.default_style_mode == 'mix' and random.random(
        ) < self.mix_prob:
            styles = [
                noise_generator((num_batches, self.style_channels))
                for _ in range(2)
            ]
        else:
            styles = [noise_generator((num_batches, self.style_channels))]
        styles = [s.to(device) for s in styles]
    # otherwise, we will adopt default noise sampler.
    else:
        device = get_module_device(self)
        assert num_batches > 0 and not input_is_latent
        if self.default_style_mode == 'mix' and random.random(
        ) < self.mix_prob:
            styles = [
                torch.randn((num_batches, self.style_channels))
                for _ in range(2)
            ]
        else:
            styles = [torch.randn((num_batches, self.style_channels))]
        styles = [s.to(device) for s in styles]

    # keep the raw inputs so they can be returned as ``noise_batch``
    if not input_is_latent:
        noise_batch = styles
        styles = [self.style_mapping(s) for s in styles]
    else:
        noise_batch = None

    if injected_noise is None:
        if randomize_noise:
            injected_noise = [None] * self.num_injected_noises
        else:
            injected_noise = [
                getattr(self, f'injected_noise_{i}')
                for i in range(self.num_injected_noises)
            ]
    # use truncation trick
    if truncation < 1:
        style_t = []
        # calculate truncation latent on the fly
        if truncation_latent is None and not hasattr(
                self, 'truncation_latent'):
            self.truncation_latent = self.get_mean_latent()
            truncation_latent = self.truncation_latent
        elif truncation_latent is None and hasattr(self,
                                                   'truncation_latent'):
            truncation_latent = self.truncation_latent

        # move each style towards the truncation latent by ``truncation``
        for style in styles:
            style_t.append(truncation_latent + truncation *
                           (style - truncation_latent))

        styles = style_t
    # no style mixing
    if len(styles) < 2:
        inject_index = self.num_latents

        # a 2-D style code is repeated for every latent slot; a code that
        # already has 3 or more dims is used as-is
        if styles[0].ndim < 3:
            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

        else:
            latent = styles[0]
    # style mixing
    else:
        if inject_index is None:
            inject_index = random.randint(1, self.num_latents - 1)

        # the first ``inject_index`` slots use styles[0], the rest styles[1]
        latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
        latent2 = styles[1].unsqueeze(1).repeat(
            1, self.num_latents - inject_index, 1)

        latent = torch.cat([latent, latent2], 1)

    curr_log_size = self.log_size if curr_scale < 0 else int(
        math.log2(curr_scale))
    # index of the conv block whose output is converted to RGB and returned
    step = curr_log_size - 2

    _index = 0
    out = latent
    # 4x4 ---> higher resolutions
    for i, (conv, to_rgb) in enumerate(zip(self.convs, self.to_rgbs)):
        # remember the previous stage's features for the transition blend
        if i > 0 and step > 0:
            out_prev = out
        # each block consumes two latent slots and two noise entries
        out = conv(
            out,
            latent[:, _index],
            latent[:, _index + 1],
            noise1=injected_noise[2 * i],
            noise2=injected_noise[2 * i + 1])
        if i == step:
            out = to_rgb(out)

            # fade in: blend with the upsampled RGB of the previous stage
            if i > 0 and 0 <= transition_weight < 1:
                skip_rgb = self.to_rgbs[i - 1](out_prev)
                skip_rgb = F.interpolate(
                    skip_rgb, scale_factor=2, mode='nearest')
                out = (1 - transition_weight
                       ) * skip_rgb + transition_weight * out
            break

        _index += 2

    img = out

    if return_latents or return_noise:
        output_dict = dict(
            fake_img=img,
            latent=latent,
            inject_index=inject_index,
            noise_batch=noise_batch)
        return output_dict

    return img
def forward(self,
            styles,
            num_batches=-1,
            return_noise=False,
            return_latents=False,
            inject_index=None,
            truncation=1,
            truncation_latent=None,
            input_is_latent=False,
            injected_noise=None,
            randomize_noise=True):
    """Forward function.

    This function has been integrated with the truncation trick. Please
    refer to the usage of `truncation` and `truncation_latent`.

    Args:
        styles (torch.Tensor | list[torch.Tensor] | callable | None): In
            StyleGAN2, you can provide noise tensor or latent tensor. Given
            a list containing more than one noise or latent tensors, style
            mixing trick will be used in training. Of course, You can
            directly give a batch of noise through a ``torch.Tensor`` or
            offer a callable function to sample a batch of noise data.
            Otherwise, the ``None`` indicates to use the default noise
            sampler.
        num_batches (int, optional): The number of batch size. Only used
            (and required to be positive) when ``styles`` is a callable or
            ``None``. Defaults to -1.
        return_noise (bool, optional): If True, ``noise_batch`` will be
            returned in a dict with ``fake_img``. Defaults to False.
        return_latents (bool, optional): If True, ``latent`` will be
            returned in a dict with ``fake_img``. Defaults to False.
        inject_index (int | None, optional): The index number for mixing
            style codes. Only used when two style codes are present; then
            ``None`` draws it with ``random.randint(1, self.num_latents -
            1)``. With a single style code it is overwritten with
            ``self.num_latents``. Defaults to None.
        truncation (float, optional): Truncation factor. Give value less
            than 1., the truncation trick will be adopted. Defaults to 1.
        truncation_latent (torch.Tensor, optional): Mean truncation latent.
            If None while truncating, ``self.truncation_latent`` is used; it
            is computed once with ``self.get_mean_latent()`` and cached on
            first use. Defaults to None.
        input_is_latent (bool, optional): If `True`, the input tensor is
            the latent tensor and ``self.style_mapping`` is skipped.
            Defaults to False.
        injected_noise (list[torch.Tensor | None] | None, optional): Noise
            for every layer: entry 0 goes to ``self.conv1``, then odd/even
            entries alternate between the up-sampling and the plain conv of
            each later stage. Given a list, the random noise will be fixed
            as this input injected noise. Defaults to None.
        randomize_noise (bool, optional): Only used when ``injected_noise``
            is None. If `False`, images are sampled with the buffered noise
            tensors ``self.injected_noise_{i}``; otherwise ``None`` is
            passed for every layer. Defaults to True.

    Returns:
        torch.Tensor | dict: Generated ``torch.float32`` image tensor or,
            when ``return_noise`` or ``return_latents`` is set, a dict with
            keys ``fake_img``, ``latent``, ``inject_index`` and
            ``noise_batch`` (``noise_batch`` is None when
            ``input_is_latent`` is True).
    """
    # receive noise and conduct sanity check.
    if isinstance(styles, torch.Tensor):
        assert styles.shape[1] == self.style_channels
        styles = [styles]
    elif mmcv.is_seq_of(styles, torch.Tensor):
        for t in styles:
            assert t.shape[-1] == self.style_channels
    # receive a noise generator and sample noise.
    elif callable(styles):
        device = get_module_device(self)
        noise_generator = styles
        assert num_batches > 0
        # with probability ``self.mix_prob`` sample two codes for mixing
        if self.default_style_mode == 'mix' and random.random(
        ) < self.mix_prob:
            styles = [
                noise_generator((num_batches, self.style_channels))
                for _ in range(2)
            ]
        else:
            styles = [noise_generator((num_batches, self.style_channels))]
        styles = [s.to(device) for s in styles]
    # otherwise, we will adopt default noise sampler.
    else:
        device = get_module_device(self)
        assert num_batches > 0 and not input_is_latent
        if self.default_style_mode == 'mix' and random.random(
        ) < self.mix_prob:
            styles = [
                torch.randn((num_batches, self.style_channels))
                for _ in range(2)
            ]
        else:
            styles = [torch.randn((num_batches, self.style_channels))]
        styles = [s.to(device) for s in styles]

    # keep the raw inputs so they can be returned as ``noise_batch``
    if not input_is_latent:
        noise_batch = styles
        styles = [self.style_mapping(s) for s in styles]
    else:
        noise_batch = None

    if injected_noise is None:
        if randomize_noise:
            injected_noise = [None] * self.num_injected_noises
        else:
            injected_noise = [
                getattr(self, f'injected_noise_{i}')
                for i in range(self.num_injected_noises)
            ]
    # use truncation trick
    if truncation < 1:
        style_t = []
        # calculate truncation latent on the fly
        if truncation_latent is None and not hasattr(
                self, 'truncation_latent'):
            self.truncation_latent = self.get_mean_latent()
            truncation_latent = self.truncation_latent
        elif truncation_latent is None and hasattr(self,
                                                   'truncation_latent'):
            truncation_latent = self.truncation_latent

        # move each style towards the truncation latent by ``truncation``
        for style in styles:
            style_t.append(truncation_latent + truncation *
                           (style - truncation_latent))

        styles = style_t
    # no style mixing
    if len(styles) < 2:
        inject_index = self.num_latents

        # a 2-D style code is repeated for every latent slot; a code that
        # already has 3 or more dims is used as-is
        if styles[0].ndim < 3:
            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

        else:
            latent = styles[0]
    # style mixing
    else:
        if inject_index is None:
            inject_index = random.randint(1, self.num_latents - 1)

        # the first ``inject_index`` slots use styles[0], the rest styles[1]
        latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
        latent2 = styles[1].unsqueeze(1).repeat(
            1, self.num_latents - inject_index, 1)

        latent = torch.cat([latent, latent2], 1)

    # 4x4 stage
    out = self.constant_input(latent)
    out = self.conv1(out, latent[:, 0], noise=injected_noise[0])
    skip = self.to_rgb1(out, latent[:, 1])

    _index = 1

    # 8x8 ---> higher resolutions
    # Each stage reads latent slots _index, _index + 1 and _index + 2 but
    # only advances by 2, so a stage's to_rgb slot is also the next
    # stage's up_conv slot. to_rgb receives the previous ``skip`` as input.
    for up_conv, conv, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], injected_noise[1::2],
            injected_noise[2::2], self.to_rgbs):
        out = up_conv(out, latent[:, _index], noise=noise1)
        out = conv(out, latent[:, _index + 1], noise=noise2)
        skip = to_rgb(out, latent[:, _index + 2], skip)
        _index += 2

    # make sure the output image is torch.float32 to avoid RunTime Error
    # in other modules
    img = skip.to(torch.float32)

    if return_latents or return_noise:
        output_dict = dict(
            fake_img=img,
            latent=latent,
            inject_index=inject_index,
            noise_batch=noise_batch)
        return output_dict

    return img
def extract_inception_features(dataloader,
                               inception,
                               num_samples,
                               inception_style='pytorch'):
    """Extract inception features for the FID metric.

    Batches are taken from ``dataloader`` until enough of them have been
    processed to cover ``num_samples`` images. Their features are gathered
    on the CPU, and the result is cut down to exactly ``num_samples`` rows.

    Args:
        dataloader (:obj:`DataLoader`): Dataloader for images. Each batch is
            a dict that stores the images under ``'real_img'``
            (unconditional datasets) or ``'img'`` (conditional datasets,
            e.g. from MMClassification).
        inception (nn.Module): Inception network.
        num_samples (int): The number of samples to be extracted.
        inception_style (str): The style of Inception network, "pytorch" or
            "stylegan". Defaults to "pytorch".

    Returns:
        torch.Tensor: Inception features with ``num_samples`` rows.
    """
    batch_size = dataloader.batch_size
    total_iters = num_samples // batch_size
    if total_iters * batch_size < num_samples:
        # a final, partially used batch is still required
        total_iters += 1

    progress = mmcv.ProgressBar(total_iters)

    collected = []
    for iter_idx, batch in enumerate(dataloader, 1):
        # Unconditional datasets expose images as "real_img", while
        # conditional datasets use "img".
        images = batch['real_img'] if 'real_img' in batch else batch['img']
        progress.update()

        # A bare (not wrapped) inception network gets its inputs moved to
        # its own device here.
        if not is_module_wrapper(inception):
            images = images.to(get_module_device(inception))

        if inception_style == 'stylegan':
            images = (images * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            feats = inception(images, return_features=True)
        else:
            feats = inception(images)[0].view(images.shape[0], -1)
        collected.append(feats.to('cpu'))

        if iter_idx >= total_iters:
            break

    # The last batch may overshoot, so keep only the requested rows.
    features = torch.cat(collected, 0)
    assert features.shape[0] >= num_samples
    features = features[:num_samples]

    # finish the progress-bar line
    sys.stdout.write('\n')
    return features
def forward(self,
            noise,
            num_batches=0,
            input_is_latent=False,
            truncation=1,
            num_truncation_layer=None,
            update_emas=False,
            force_fp32=True,
            return_noise=False,
            return_latents=False):
    """Forward Function for stylegan3.

    Args:
        noise (torch.Tensor | callable | None): You can directly give a
            batch of noise through a ``torch.Tensor`` of shape ``(n, c)``
            or offer a callable function to sample a batch of noise data.
            Otherwise, the ``None`` indicates to use the default noise
            sampler. ``c`` must equal ``self.style_channels`` when
            ``input_is_latent`` is True and ``self.noise_size`` otherwise.
        num_batches (int, optional): The number of batch size. Must be
            positive when ``noise`` is a callable or None. Defaults to 0.
        input_is_latent (bool, optional): If `True`, the input tensor is
            the latent tensor; the style mapping (and therefore truncation)
            is skipped and the latent is repeated for all ``self.num_ws``
            layers. Defaults to False.
        truncation (float, optional): Truncation factor. Give value less
            than 1., the truncation trick will be adopted. Defaults to 1.
        num_truncation_layer (int, optional): Number of layers use
            truncated latent. Defaults to None.
        update_emas (bool, optional): Whether update moving average of
            mean latent. Defaults to False.
        force_fp32 (bool, optional): Force fp32 ignore the weights.
            Defaults to True.
        return_noise (bool, optional): If True, ``noise_batch`` will be
            returned in a dict with ``fake_img``. Defaults to False.
        return_latents (bool, optional): If True, ``latent`` will be
            returned in a dict with ``fake_img``. Defaults to False.

    Returns:
        torch.Tensor | dict: Generated image tensor or, when
            ``return_noise`` or ``return_latents`` is set, a dict with keys
            ``fake_img``, ``noise_batch`` and ``latent``.
    """
    # if input is latent, set noise size as the style_channels
    noise_size = (
        self.style_channels if input_is_latent else self.noise_size)

    if isinstance(noise, torch.Tensor):
        # Check the rank before indexing ``shape[1]`` so that a wrongly
        # shaped input hits this descriptive assertion, not an IndexError.
        assert noise.ndim == 2, ('The noise should be in shape of (n, c), '
                                 f'but got {noise.shape}')
        assert noise.shape[1] == noise_size, (
            f'Expected {noise_size} channels, but got {noise.shape[1]}')
        noise_batch = noise
    # receive a noise generator and sample noise.
    elif callable(noise):
        assert num_batches > 0
        noise_batch = noise((num_batches, noise_size))
    # otherwise, we will adopt default noise sampler.
    else:
        assert num_batches > 0
        noise_batch = torch.randn((num_batches, noise_size))

    device = get_module_device(self)
    noise_batch = noise_batch.to(device)

    if input_is_latent:
        ws = noise_batch.unsqueeze(1).repeat([1, self.num_ws, 1])
    else:
        ws = self.style_mapping(
            noise_batch,
            truncation=truncation,
            num_truncation_layer=num_truncation_layer,
            update_emas=update_emas)
    out_img = self.synthesis(
        ws, update_emas=update_emas, force_fp32=force_fp32)

    # swap the first and third channels when configured to emit BGR
    if self.rgb2bgr:
        out_img = out_img[:, [2, 1, 0], ...]

    if return_noise or return_latents:
        output = dict(fake_img=out_img, noise_batch=noise_batch, latent=ws)
        return output

    return out_img
def batch_inference(generator,
                    noise,
                    embedding=None,
                    num_batches=-1,
                    max_batch_size=16,
                    dict_key=None,
                    **kwargs):
    """Inference function to get a batch of desired data from output
    dictionary of generator.

    The work is split into chunks of at most ``max_batch_size`` samples.
    Each chunk's output is moved to the CPU, and the chunks are
    concatenated along dim 0.

    Args:
        generator (nn.Module): Generator of a conditional model. It is
            called as ``generator(noise_chunk, label=embedding_chunk,
            num_batches=chunk_size, **kwargs)``.
        noise (Tensor | list[torch.Tensor] | None): A batch of noise Tensor.
            For a list, only ``noise[0]`` is used, and each of its chunks
            is passed to the generator wrapped in a one-element list.
            ``None`` passes ``None`` to the generator for every chunk.
        embedding (Tensor, optional): Embedding tensor of label for
            conditional models. Defaults to None.
        num_batches (int, optional): The number of samples for inference.
            Only used when both ``noise`` and ``embedding`` are None.
            Otherwise it is inferred from ``embedding`` (preferred) or
            ``noise``. Defaults to -1.
        max_batch_size (int, optional): The number of batch size for
            inference. Defaults to 16.
        dict_key (str, optional): key used to get results from output
            dictionary of generator. Defaults to None.
        **kwargs: Extra keyword arguments forwarded to ``generator``.

    Returns:
        torch.Tensor: Tensor of output image, noise batch or label batch.
    """
    # Resolve the total sample count before building any group list so the
    # noise, embedding and batch-size groups agree. Without this, passing
    # ``noise=None`` with an ``embedding`` and the default
    # ``num_batches=-1`` would create no placeholder noise groups, and no
    # sample would be generated.
    if noise is not None:
        first_noise = noise if isinstance(noise, torch.Tensor) else noise[0]
        num_batches = first_noise.shape[0]
    if embedding is not None:
        assert isinstance(embedding, torch.Tensor)
        num_batches = embedding.shape[0]

    num_groups = num_batches // max_batch_size + (
        1 if num_batches % max_batch_size > 0 else 0)

    # split noise into groups
    if noise is None:
        noise_group = [None] * num_groups
    elif isinstance(noise, torch.Tensor):
        noise_group = torch.split(noise, max_batch_size, 0)
    else:
        noise_group = [[noise_tensor] for noise_tensor in torch.split(
            noise[0], max_batch_size, 0)]

    # split embedding into groups
    if embedding is None:
        embedding_group = [None] * num_groups
    else:
        embedding_group = torch.split(embedding, max_batch_size, 0)

    # split batchsize into groups
    batchsize_group = [max_batch_size] * (num_batches // max_batch_size)
    if num_batches % max_batch_size > 0:
        batchsize_group.append(num_batches % max_batch_size)

    device = get_module_device(generator)
    outputs = []
    for _noise, _embedding, _num_batches in zip(noise_group, embedding_group,
                                                batchsize_group):
        if isinstance(_noise, torch.Tensor):
            _noise = _noise.to(device)
        if isinstance(_noise, list):
            _noise = [ele.to(device) for ele in _noise]
        if _embedding is not None:
            _embedding = _embedding.to(device)
        output = generator(
            _noise, label=_embedding, num_batches=_num_batches, **kwargs)
        output = output[dict_key] if dict_key else output
        if isinstance(output, list):
            output = output[0]
        # once obtaining sampled results, we immediately put them into cpu
        # to save cuda memory
        outputs.append(output.to('cpu'))
    outputs = torch.cat(outputs, dim=0)
    return outputs