def forward(self, batch_size: int, left_borders_idx: Tensor) -> Tensor:
    misc.assert_shape(left_borders_idx, [batch_size])
    noise = torch.randn(batch_size, self.channel_dim, self.resolution, self.resolution, device=left_borders_idx.device)
    out = self.coord_fuser(noise, left_borders_idx=left_borders_idx, memory_format=torch.contiguous_format)
    return out
def forward(self, ws, mask=None, **block_kwargs):
    if ws.ndim == 3:
        ws = ws.unsqueeze(1)
    block_ws = []
    with torch.autograd.profiler.record_function('split_ws'):
        misc.assert_shape(ws, [None, None, self.num_ws, self.w_dim])
        ws = ws.to(torch.float32)
        w_idx = 0
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            block_ws.append(ws.narrow(2, w_idx, block.num_conv + block.num_torgb))
            w_idx += block.num_conv

    if mask is None:
        # Default to a uniform mask that averages the style maps.
        mask = ws.new_ones([1, ws.shape[1], self.img_resolution, self.img_resolution]) / ws.shape[1]
    misc.assert_shape(mask, [None, ws.shape[1], self.img_resolution, self.img_resolution])

    # Build a mask pyramid so every block receives a mask at its own resolution.
    masks = [mask]
    for _ in range(len(self.block_resolutions) - 1):
        masks.insert(0, F.avg_pool2d(masks[0], 2))

    x = img = None
    for res, cur_ws, cur_mask in zip(self.block_resolutions, block_ws, masks):
        block = getattr(self, f'b{res}')
        x, img = block(x, img, cur_ws, cur_mask, **block_kwargs)
    return img
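# Illustrative sketch (not part of the model; sizes are hypothetical): the mask
# pyramid above halves the full-res mask once per extra block, so the block at
# each resolution receives a mask matching its own feature map.
def _demo_mask_pyramid():
    import torch
    import torch.nn.functional as F
    mask = torch.rand(1, 3, 64, 64)              # [1, num_styles, H, W]
    masks = [mask]
    for _ in range(4 - 1):                       # e.g. 4 block resolutions: 8..64
        masks.insert(0, F.avg_pool2d(masks[0], 2))
    return [m.shape[-1] for m in masks]          # [8, 16, 32, 64]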
def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
    assert noise_mode in ['random', 'const', 'none']
    in_resolution = self.resolution // self.up
    misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution])
    styles = self.affine(w)

    noise = None
    if self.use_noise and noise_mode == 'random':
        noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
    if self.use_noise and noise_mode == 'const':
        noise = self.noise_const * self.noise_strength

    flip_weight = (self.up == 1)  # slightly faster
    x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
                         padding=self.padding, resample_filter=self.resample_filter,
                         flip_weight=flip_weight, fused_modconv=fused_modconv)

    act_gain = self.act_gain * gain
    act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
    x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
    return x
def forward(self, x, img, force_fp32=False):
    dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
    memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format

    # Input.
    if x is not None:
        misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
        x = x.to(dtype=dtype, memory_format=memory_format)

    # FromRGB.
    if self.in_channels == 0 or self.architecture == 'skip':
        misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
        img = img.to(dtype=dtype, memory_format=memory_format)
        y = self.fromrgb(img)
        x = x + y if x is not None else y
        img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None

    # Main layers.
    if self.architecture == 'resnet':
        y = self.skip(x, gain=np.sqrt(0.5))
        x = self.conv0(x)
        x = self.conv1(x, gain=np.sqrt(0.5))
        x = y.add_(x)
    else:
        x = self.conv0(x)
        x = self.conv1(x)

    assert x.dtype == dtype
    return x, img
def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
    # Embed, normalize, and concat inputs.
    x = None
    with torch.autograd.profiler.record_function('input'):
        if self.z_dim > 0:
            misc.assert_shape(z, [None, self.z_dim])
            x = normalize_2nd_moment(z.to(torch.float32))
        if self.c_dim > 0:
            misc.assert_shape(c, [None, self.c_dim])
            y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
            x = torch.cat([x, y], dim=1) if x is not None else y

    # Main layers.
    for idx in range(self.num_layers):
        layer = getattr(self, f'fc{idx}')
        x = layer(x)

    # Update moving averages of W (mean and covariance).
    if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
        with torch.autograd.profiler.record_function('update_w_avg'):
            self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
            self.w_cov.copy_((self.w_avg_beta * self.w_cov) + (
                (self.w_avg_beta - self.w_avg_beta ** 2) *
                (x.detach() - self.w_avg).T @ (x.detach() - self.w_avg)))

    # Broadcast.
    if self.num_ws is not None:
        with torch.autograd.profiler.record_function('broadcast'):
            x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

    # Apply truncation.
    if truncation_psi != 1:
        with torch.autograd.profiler.record_function('truncate'):
            assert self.w_avg_beta is not None
            if self.num_ws is None or truncation_cutoff is None:
                x = self.w_avg.lerp(x, truncation_psi)
            else:
                x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
    return x
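# Sketch of the truncation trick applied above (hypothetical sizes): styles are
# pulled toward the tracked average w_avg with a lerp, and with a cutoff only
# the earliest (coarsest) layers are truncated.
def _demo_truncation():
    import torch
    w_avg = torch.zeros(512)
    x = torch.randn(4, 10, 512)                      # [batch, num_ws, w_dim]
    psi, cutoff = 0.7, 8
    x[:, :cutoff] = w_avg.lerp(x[:, :cutoff], psi)   # w_avg + psi * (x - w_avg)
    return x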
def modulated_conv2d(
    x,                      # Input tensor of shape [batch_size, in_channels, in_height, in_width].
    weight,                 # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
    styles,                 # Modulation coefficients of shape [batch_size, in_channels].
    noise           = None, # Optional noise tensor to add to the output activations.
    up              = 1,    # Integer upsampling factor.
    down            = 1,    # Integer downsampling factor.
    padding         = 0,    # Padding with respect to the upsampled image.
    resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
    demodulate      = True, # Apply weight demodulation?
    flip_weight     = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
    fused_modconv   = True, # Perform modulation, convolution, and demodulation as a single fused operation?
):
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape
    misc.assert_shape(weight, [out_channels, in_channels, kh, kw])  # [OIkk]
    misc.assert_shape(x, [batch_size, in_channels, None, None])     # [NIHW]
    misc.assert_shape(styles, [batch_size, in_channels])            # [NI]

    # Pre-normalize inputs to avoid FP16 overflow.
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True))  # max_Ikk
        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True)  # max_I

    # Calculate per-sample weights and demodulation coefficients.
    w = None
    dcoefs = None
    if demodulate or fused_modconv:
        w = weight.unsqueeze(0)                          # [NOIkk]
        w = w * styles.reshape(batch_size, 1, -1, 1, 1)  # [NOIkk]
    if demodulate:
        dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt()  # [NO]
    if demodulate and fused_modconv:
        w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1)  # [NOIkk]

    # Execute by scaling the activations before and after the convolution.
    if not fused_modconv:
        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
        x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
        if demodulate and noise is not None:
            x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
        elif demodulate:
            x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x

    # Execute as one fused op using grouped convolution.
    with misc.suppress_tracer_warnings():  # this value will be treated as a constant
        batch_size = int(batch_size)
    misc.assert_shape(x, [batch_size, in_channels, None, None])
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
    x = x.reshape(batch_size, -1, *x.shape[2:])
    if noise is not None:
        x = x.add_(noise)
    return x
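# Self-contained sketch (plain torch.nn.functional, none of the torch_utils
# dependencies; sizes are hypothetical) checking that the two execution paths
# above agree: scaling activations before/after one shared conv equals folding
# the styles and demodulation coefficients into per-sample weights and running
# a single grouped convolution.
def _demo_modconv_paths_agree():
    import torch
    import torch.nn.functional as F
    torch.manual_seed(0)
    b, in_c, out_c = 4, 8, 12
    x = torch.randn(b, in_c, 16, 16)
    weight = torch.randn(out_c, in_c, 3, 3)
    styles = torch.randn(b, in_c)
    # Unfused path: scale inputs, convolve once, demodulate the outputs.
    w = weight.unsqueeze(0) * styles.reshape(b, 1, -1, 1, 1)   # [N, O, I, k, k]
    d = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()         # [N, O]
    y1 = F.conv2d(x * styles.reshape(b, -1, 1, 1), weight, padding=1) * d.reshape(b, -1, 1, 1)
    # Fused path: per-sample weights via one grouped convolution.
    w = w * d.reshape(b, -1, 1, 1, 1)
    y2 = F.conv2d(x.reshape(1, -1, 16, 16), w.reshape(-1, in_c, 3, 3), padding=1, groups=b)
    y2 = y2.reshape(b, -1, 16, 16)
    return torch.allclose(y1, y2, atol=1e-4)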
def forward(self, ws, c=None, **block_kwargs):
    block_ws = []
    with torch.autograd.profiler.record_function('split_ws'):
        misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
        ws = ws.to(torch.float32)
        w_idx = 0
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
            w_idx += block.num_conv

    x = img = None
    for res, cur_ws in zip(self.block_resolutions, block_ws):
        block = getattr(self, f'b{res}')
        x, img = block(x, img, cur_ws, **block_kwargs)
    return img
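# Sketch of the w-slicing logic above (hypothetical layer counts): each block
# consumes `num_conv` latents and additionally peeks at the next `num_torgb`
# ones, which is why w_idx only advances by num_conv and consecutive slices
# overlap.
def _demo_split_ws():
    import torch
    ws = torch.randn(4, 14, 512)                 # [batch, num_ws, w_dim]
    num_conv, num_torgb = [1, 2, 2], [1, 1, 1]
    w_idx, block_ws = 0, []
    for nc, nt in zip(num_conv, num_torgb):
        block_ws.append(ws.narrow(1, w_idx, nc + nt))
        w_idx += nc
    return [bw.shape[1] for bw in block_ws]      # [2, 3, 3]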
def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
    assert noise_mode in ['random', 'const', 'none']
    in_resolution = self.resolution // self.up
    misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution])
    styles = self.affine(w)

    noise = None
    if self.cfg.use_noise and noise_mode == 'random':
        noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
    if self.cfg.use_noise and noise_mode == 'const':
        noise = self.noise_const * self.noise_strength

    flip_weight = (self.up == 1)  # slightly faster

    if self.instance_norm:
        x = x / (x.std(dim=[2, 3], keepdim=True) + 1e-8)  # [batch_size, c, h, w]

    if self.cfg.fmm.enabled:
        x = fmm_modulate_linear(x=x, weight=self.weight, styles=styles, noise=noise, activation=self.cfg.fmm.activation)
    else:
        x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
                             padding=self.padding, resample_filter=self.resample_filter,
                             flip_weight=flip_weight, fused_modconv=fused_modconv)

    act_gain = self.act_gain * gain
    act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
    x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
    return x
def forward(self, batch_size: int, shifts: Optional[Tensor] = None) -> Tensor:
    x = self.const_input.unsqueeze(0).repeat([batch_size, 1, 1, 1])  # [b, c, h, w]

    if shifts is not None:
        misc.assert_shape(shifts, [batch_size, 2])
        assert shifts.max().item() <= 1.0
        assert shifts.min().item() >= -1.0
        coords = generate_coords(batch_size, self.const_input.shape[1], device=x.device, align_corners=True)  # [b, 2, h, w]

        # A faster alternative would be to apply the shift directly:
        #   coords = coords + shifts.view(batch_size, 2, 1, 1)  # [b, 2, h, w]
        # and then convert into F.grid_sample coordinates:
        #   1. convert the range: [-1, 1] => [0, 2];
        #   2. emulate padding_mode=replicate by wrapping the coordinates;
        #   3. convert back to the [-1, 1] range;
        #   4. flip the y coordinate (F.grid_sample uses flipped coordinates);
        #   5. permute to the [b, h, w, 2] shape F.grid_sample expects.
        # Instead we take a slower but less error-prone approach: convert shifts
        # from [-1, 1] to [-2, 2], so the coordinates now live in [-3, 3], and
        # sample from a triple-width copy of the input.
        coords = coords + 2 * shifts.view(batch_size, 2, 1, 1)  # [b, 2, h, w]
        coords = coords / 3  # [-3, 3] => [-1, 1] range
        coords = coords.permute(0, 2, 3, 1)  # [b, h, w, 2]
        assert coords.min().item() >= -1
        assert coords.max().item() <= 1

        x = torch.cat([x, x, x], dim=3)  # [b, c, h, w * 3]
        x = F.grid_sample(x, coords, mode='bilinear', align_corners=True)  # [b, c, h, w]
    return x
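# Minimal sketch of coordinate-grid shifting with F.grid_sample (the helper
# below is hypothetical and independent of generate_coords): offsetting a
# normalized [-1, 1] grid by `shifts` translates the sampled feature map.
def _demo_shift_with_grid_sample():
    import torch
    import torch.nn.functional as F
    b, c, h, w = 1, 1, 4, 4
    x = torch.arange(float(h * w)).view(b, c, h, w)
    shifts = torch.tensor([[0.5, 0.0]])          # [b, 2], (x, y) in grid units
    ys, xs = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing='ij')
    grid = torch.stack([xs, ys], dim=-1).unsqueeze(0).repeat(b, 1, 1, 1)  # [b, h, w, 2]
    grid = grid + shifts.view(b, 1, 1, 2)
    return F.grid_sample(x, grid, mode='bilinear', align_corners=True, padding_mode='border')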
def fmm_modulate(
        conv_weight: Tensor,
        fmm_weights: Tensor,
        fmm_mod_type: str = 'mult',
        demodulate: bool = False,
        fmm_add_weight: float = 1.0,
        activation: Optional[str] = None) -> Tensor:
    """
    Applies FMM fmm_weights to a given conv weight tensor.
    """
    batch_size, out_channels, in_channels, kh, kw = conv_weight.shape
    assert fmm_weights.shape[1] % (in_channels + out_channels) == 0
    rank = fmm_weights.shape[1] // (in_channels + out_channels)

    lhs = fmm_weights[:, :rank * out_channels].view(batch_size, out_channels, rank)
    rhs = fmm_weights[:, rank * out_channels:].view(batch_size, rank, in_channels)
    modulation = lhs @ rhs  # [batch_size, out_channels, in_channels]
    modulation = modulation / np.sqrt(rank)
    misc.assert_shape(modulation, [batch_size, out_channels, in_channels])
    modulation = modulation.unsqueeze(3).unsqueeze(4)  # [batch_size, out_channels, in_channels, 1, 1]

    if activation == 'tanh':
        modulation = modulation.tanh()
    elif activation in ['linear', None]:
        pass
    elif activation == 'sigmoid':
        modulation = modulation.sigmoid() - 0.5
    else:
        raise NotImplementedError

    if fmm_mod_type == 'mult':
        out = conv_weight * (modulation + 1.0)
    elif fmm_mod_type == 'add':
        out = conv_weight + fmm_add_weight * modulation
    else:
        raise NotImplementedError

    if demodulate:
        out = out / out.norm(dim=[2, 3, 4], keepdim=True)
    return out
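# Sketch of the low-rank factorization used above (hypothetical sizes): the
# predicted vector splits into lhs [O, r] and rhs [r, I] factors whose product
# is the per-sample [O, I] modulation matrix, scaled by 1/sqrt(rank).
def _demo_fmm_modulation():
    import torch
    b, out_c, in_c, rank = 2, 6, 4, 3
    fmm_weights = torch.randn(b, rank * (out_c + in_c))
    lhs = fmm_weights[:, :rank * out_c].view(b, out_c, rank)
    rhs = fmm_weights[:, rank * out_c:].view(b, rank, in_c)
    modulation = (lhs @ rhs) / rank ** 0.5                     # [b, out_c, in_c]
    conv_weight = torch.randn(b, out_c, in_c, 3, 3)
    return conv_weight * (modulation[..., None, None] + 1.0)   # 'mult' mode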
def forward(self, x, img, cmap, force_fp32=False):
    misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])  # [NCHW]
    _ = force_fp32  # unused
    dtype = torch.float32
    memory_format = torch.contiguous_format

    # FromRGB.
    x = x.to(dtype=dtype, memory_format=memory_format)
    if self.architecture == 'skip':
        misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
        img = img.to(dtype=dtype, memory_format=memory_format)
        x = x + self.fromrgb(img)

    # Main layers.
    if self.mbstd is not None:
        x = self.mbstd(x)
    x = self.conv(x)
    x = self.fc(x.flatten(1))
    x = self.out(x)

    # Conditioning.
    if self.cmap_dim > 0:
        misc.assert_shape(cmap, [None, self.cmap_dim])
        x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))

    assert x.dtype == dtype
    return x
def fast_bilinear_mult_row(x: Tensor, styles: Tensor, shifts: Optional[Tensor] = None) -> Tensor:
    b, c, h, w = x.shape
    context_size = 2
    misc.assert_shape(styles, [b, c, context_size + 1])
    centers = shifts
    if centers is None:
        centers = torch.zeros(b, 2, dtype=styles.dtype, device=styles.device)
    misc.assert_shape(centers, [b, 2])
    assert centers.min().item() >= -1.0
    assert centers.max().item() <= 1.0

    # Centers are in the [-1, 1] range, but w_before/w_after positions correspond to -2/2.
    # Constructing the bounds for each center.
    # The size of the square is 2: it is in [-1, 1] x [-1, 1].
    # Bounds correspond to the left and right borders.
    assert context_size == 2
    bounds = torch.stack([
        torch.stack([centers[:, 0] - 1, centers[:, 1]], dim=1),
        torch.stack([centers[:, 0] + 1, centers[:, 1]], dim=1),
    ], dim=1)  # [b, 2, 2]
    bounds = bounds.unsqueeze(1)  # [b, 1, 2, 2] == [b, h, w, 2]
    # F.grid_sample assumes the [-1, 1] range, so adjust:
    bounds.mul_(0.5)
    # Also, for F.grid_sample we need to flip the y coordinate.
    bounds[:, :, :, 1].mul_(-1.0)

    # Now we can get our interpolated embeddings.
    w_bounds = F.grid_sample(styles.unsqueeze(2), bounds.to(styles.dtype), mode='bilinear', align_corners=True)  # [b, c, 1, 2]
    # Now we can interpolate and modulate.
    modulation = F.interpolate(w_bounds, size=(1, w), mode='bilinear', align_corners=True)  # [b, c, 1, w]
    x = x * modulation  # [b, c, h, w]
    return x
def forward(self, x, w, mask, noise_mode='random', fused_modconv=True, gain=1):
    assert noise_mode in ['random', 'const', 'none']
    in_resolution = self.resolution // self.up
    misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution])
    w_n, w_m, _ = w.shape
    styles = self.affine(w.view([w_n * w_m, -1]))

    noise = None
    if self.use_noise and noise_mode == 'random':
        noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
    if self.use_noise and noise_mode == 'const':
        noise = self.noise_const * self.noise_strength

    flip_weight = (self.up == 1)  # slightly faster
    act_gain = self.act_gain * gain
    act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None

    # Run the convolution once per (sample, style) pair, then blend with the mask.
    x = x.repeat_interleave(w_m, 0)
    x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
                         padding=self.padding, resample_filter=self.resample_filter,
                         flip_weight=flip_weight, fused_modconv=fused_modconv)
    x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
    return x.view(w_n, w_m, *x.shape[1:]).mul(mask.unsqueeze(2)).sum(1)
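# Sketch of the blending performed in the return statement above (hypothetical
# sizes): the layer runs once per (sample, style) pair via repeat_interleave,
# and the w_m results are combined per pixel using the mask weights.
def _demo_mask_blend():
    import torch
    w_n, w_m, c, h, w = 2, 3, 4, 8, 8
    outs = torch.randn(w_n * w_m, c, h, w)       # one output per (sample, style)
    mask = torch.rand(w_n, w_m, h, w)
    mask = mask / mask.sum(dim=1, keepdim=True)  # per-pixel weights over styles
    return outs.view(w_n, w_m, c, h, w).mul(mask.unsqueeze(2)).sum(1)  # [w_n, c, h, w]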
def fast_manual_bilinear_mult_row(x: Tensor, styles: Tensor, left_borders_idx: Tensor, grid_size: int,
                                  w_coord_dist: float, w_lerp_multiplier: float = 1.0) -> Tensor:
    b, c, h, w = x.shape
    misc.assert_shape(styles, [b, 3, c])
    misc.assert_shape(left_borders_idx, [b])

    w_dist = int(0.5 * w_coord_dist * w)
    interp_coefs = torch.linspace(1 / (2 * w_dist), 1 - 1 / (2 * w_dist), w_dist, device=x.device, dtype=styles.dtype)  # [w_dist]
    interp_coefs = interp_coefs * w_lerp_multiplier
    interp_coefs = interp_coefs.view(1, w_dist, 1)  # [1, w_dist, 1]

    styles_grid_left = styles[:, 0].unsqueeze(1) * (w_lerp_multiplier - interp_coefs) + styles[:, 1].unsqueeze(1) * interp_coefs   # [b, w_dist, c]
    styles_grid_right = styles[:, 1].unsqueeze(1) * (w_lerp_multiplier - interp_coefs) + styles[:, 2].unsqueeze(1) * interp_coefs  # [b, w_dist, c]
    styles_grid = torch.cat([styles_grid_left, styles_grid_right], dim=1).to(x.dtype)  # [b, 2 * w_dist, c]

    # Left borders were randomly sampled in the [0, 2 * w_dist - w] integer range.
    # We use them to select the corresponding styles.
    patch_size = w // grid_size
    batch_idx = torch.arange(b, device=x.device).view(-1, 1).repeat(1, w)  # [b, w]
    grid_idx = (left_borders_idx.unsqueeze(1) * patch_size) + torch.arange(w, device=x.device).view(1, -1)  # [b, w]
    latents = styles_grid[batch_idx, grid_idx].permute(0, 2, 1)  # [b, c, w]
    x = x * latents.unsqueeze(2)  # [b, c, h, w]
    return x
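# Sketch of the window gather above (hypothetical sizes): left border indices
# select, per sample, a width-w window of per-column styles out of the long
# styles_grid via plain integer indexing.
def _demo_window_gather():
    import torch
    b, c, w, total = 2, 4, 8, 20
    styles_grid = torch.randn(b, total, c)
    left = torch.randint(0, total - w + 1, (b,))
    batch_idx = torch.arange(b).view(-1, 1).expand(b, w)           # [b, w]
    grid_idx = left.view(-1, 1) + torch.arange(w).view(1, -1)      # [b, w]
    return styles_grid[batch_idx, grid_idx].permute(0, 2, 1)       # [b, c, w]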
def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
    misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
    w_iter = iter(ws.unbind(dim=1))
    dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
    memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
    if fused_modconv is None:
        with misc.suppress_tracer_warnings():  # this value will be treated as a constant
            fused_modconv = (not self.training) and (dtype == torch.float32 or (isinstance(x, Tensor) and int(x.shape[0]) == 1))

    # Input.
    if self.in_channels == 0:
        conv1_w = next(w_iter)
        x = self.input(ws.shape[0], conv1_w, device=ws.device, dtype=dtype, memory_format=memory_format)
    else:
        misc.assert_shape(x, [None, self.in_channels, self.input_resolution, self.input_resolution])
        x = x.to(dtype=dtype, memory_format=memory_format)
        x = maybe_upsample(x, self.cfg.upsampling_mode, self.up)

    # Main layers.
    if self.in_channels == 0:
        x = self.conv1(x, conv1_w, fused_modconv=fused_modconv, **layer_kwargs)
    elif self.architecture == 'resnet':
        y = self.skip(x, gain=np.sqrt(0.5))
        x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
        x = y.add_(x)
    else:
        conv0_w = next(w_iter)
        if self.coord_fuser is not None:
            x = self.coord_fuser(x, conv0_w, dtype=dtype, memory_format=memory_format)
        x = self.conv0(x, conv0_w, fused_modconv=fused_modconv, **layer_kwargs)
        x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

    if self.extra_convs is not None:
        for conv, w in zip(self.extra_convs, w_iter):
            x = conv(x, w, fused_modconv=fused_modconv, **layer_kwargs)

    # ToRGB.
    if img is not None:
        misc.assert_shape(img, [None, self.img_channels, self.input_resolution, self.input_resolution])
        if self.up == 2:
            if self.cfg.upsampling_mode is None:
                img = upfirdn2d.upsample2d(img, self.resample_filter)
            else:
                img = maybe_upsample(img, self.cfg.upsampling_mode, 2)
    if self.is_last or self.architecture == 'skip':
        y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
        y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
        img = img.add_(y) if img is not None else y

    assert x.dtype == dtype
    assert img is None or img.dtype == torch.float32
    return x, img
def fast_bilinear_mult(x, styles):
    """
    x: [b, c, h, w], styles: [b, c, 2, 2]
    """
    b, c, h, w = x.shape
    misc.assert_shape(styles, [b, c, 2, 2])
    kwargs = dict(device=x.device, dtype=x.dtype)

    top_to_bottom = torch.linspace(1, 0, h, **kwargs).unsqueeze(1)
    left_to_right = torch.linspace(1, 0, w, **kwargs).unsqueeze(0)
    coefs_11 = top_to_bottom * left_to_right                  # [h, w]
    coefs_12 = top_to_bottom * (1.0 - left_to_right)          # [h, w]
    coefs_21 = (1.0 - top_to_bottom) * left_to_right          # [h, w]
    coefs_22 = (1.0 - top_to_bottom) * (1.0 - left_to_right)  # [h, w]
    coefs = torch.stack([coefs_11, coefs_12, coefs_21, coefs_22])  # [4, h, w]
    coefs = coefs.unsqueeze(0).unsqueeze(2)  # [1, 4, 1, h, w]

    xs = (x.unsqueeze(1) * coefs)  # [b, 4, c, h, w]
    styles = styles.permute(0, 2, 3, 1).view(b, 4, c)  # [b, 4, c]
    styles = styles.view(b, 4, c, 1, 1)  # [b, 4, c, 1, 1]
    y = (xs * styles).sum(dim=1)  # [b, c, h, w]
    return y
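# Sanity sketch: the hand-rolled four-corner blending above is bilinear
# interpolation of the 2x2 style grid, so under align_corners=True semantics it
# should match F.interpolate applied to `styles` directly (sizes hypothetical).
def _demo_bilinear_equivalence():
    import torch
    import torch.nn.functional as F
    b, c, h, w = 2, 8, 16, 16
    x = torch.randn(b, c, h, w)
    styles = torch.randn(b, c, 2, 2)
    ref = x * F.interpolate(styles, size=(h, w), mode='bilinear', align_corners=True)
    return torch.allclose(fast_bilinear_mult(x, styles), ref, atol=1e-5)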
def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
    misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
    w_iter = iter(ws.unbind(dim=1))
    dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
    memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
    if fused_modconv is None:
        with misc.suppress_tracer_warnings():  # this value will be treated as a constant
            fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)

    # Input.
    if self.in_channels == 0:
        x = self.const.to(dtype=dtype, memory_format=memory_format)
        x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
    else:
        # !!! custom
        misc.assert_shape(x, [None, self.in_channels, self.resolution * self.init_res[0] // 8, self.resolution * self.init_res[1] // 8])
        # misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
        x = x.to(dtype=dtype, memory_format=memory_format)

    # Main layers.
    if self.in_channels == 0:
        x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
    elif self.architecture == 'resnet':
        y = self.skip(x, gain=np.sqrt(0.5))
        x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
        x = y.add_(x)
    else:
        x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

    # ToRGB.
    if img is not None:
        # !!! custom
        misc.assert_shape(img, [None, self.img_channels, self.resolution * self.init_res[0] // 8, self.resolution * self.init_res[1] // 8])
        # misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
        img = upfirdn2d.upsample2d(img, self.resample_filter)
    if self.is_last or self.architecture == 'skip':
        y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
        y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
        img = img.add_(y) if img is not None else y

    assert x.dtype == dtype
    assert img is None or img.dtype == torch.float32
    return x, img
def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs):
    _ = update_emas  # unused
    misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
    w_iter = iter(ws.unbind(dim=1))
    if ws.device.type != 'cuda':
        force_fp32 = True
    dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
    memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
    if fused_modconv is None:
        fused_modconv = self.fused_modconv_default
    if fused_modconv == 'inference_only':
        fused_modconv = (not self.training)

    # Input.
    if self.in_channels == 0:
        x = self.const.to(dtype=dtype, memory_format=memory_format)
        x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
    else:
        misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
        x = x.to(dtype=dtype, memory_format=memory_format)

    # Main layers.
    if self.in_channels == 0:
        x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
    elif self.architecture == 'resnet':
        y = self.skip(x, gain=np.sqrt(0.5))
        x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
        x = y.add_(x)
    else:
        x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

    # ToRGB.
    if img is not None:
        misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
        img = upfirdn2d.upsample2d(img, self.resample_filter)
    if self.is_last or self.architecture == 'skip':
        y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
        y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
        img = img.add_(y) if img is not None else y

    assert x.dtype == dtype
    assert img is None or img.dtype == torch.float32
    return x, img
def forward(self, batch_size: int, w: Tensor, w_context: Tensor, left_borders_idx: Tensor) -> Tensor:
    misc.assert_shape(w, [batch_size, self.w_dim])
    misc.assert_shape(w_context, [batch_size, 2, self.w_dim])
    misc.assert_shape(left_borders_idx, [batch_size])

    # Computing the global features.
    w_all = torch.stack([w_context[:, 0], w, w_context[:, 1]], dim=1)  # [b, 3, w_dim]
    styles = self.affine(w_all.view(-1, self.w_dim)).view(batch_size, 3, self.channel_dim)  # [b, 3, c]
    raw_const_inputs = self.input_column.unsqueeze(0).unsqueeze(3).repeat(batch_size, 1, 1, self.resolution)  # [b, c, h, w]
    latents = fast_manual_bilinear_mult_row(raw_const_inputs, styles, left_borders_idx, self.grid_size, self.w_coord_dist, self.w_lerp_multiplier)

    # Now, for each cell in the grid we need to compute its high-frequency coordinates.
    # Otherwise, it would be too difficult for the model to understand the relative positions.
    coords = generate_shifted_coords(left_borders_idx, self.resolution, self.grid_size, self.w_coord_dist, device=w.device)
    bases = self.basis.unsqueeze(0).repeat(batch_size, 1, 1)  # [batch_size, dim, 2]
    raw_coord_embs = torch.einsum('bdc,bcxy->bdxy', bases, coords)  # [batch_size, dim, img_size, img_size]
    coord_embs = torch.cat([raw_coord_embs.sin(), raw_coord_embs.cos()], dim=1)  # [batch_size, dim * 2, img_size, img_size]

    # Computing the final inputs.
    inputs = torch.cat([latents, coord_embs], dim=1)  # [b, c, grid_size, grid_size]
    return inputs
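# Minimal Fourier-feature sketch matching the einsum above (hypothetical sizes):
# a per-sample frequency basis maps 2D coordinates to raw phases, and sin/cos
# of those phases form the positional embedding channels.
def _demo_fourier_embs():
    import torch
    b, dim, size = 2, 16, 8
    ys, xs = torch.meshgrid(torch.linspace(-1, 1, size), torch.linspace(-1, 1, size), indexing='ij')
    coords = torch.stack([xs, ys]).unsqueeze(0).repeat(b, 1, 1, 1)  # [b, 2, h, w]
    basis = torch.randn(b, dim, 2)                                  # per-sample frequencies
    raw = torch.einsum('bdc,bcxy->bdxy', basis, coords)             # [b, dim, h, w]
    return torch.cat([raw.sin(), raw.cos()], dim=1)                 # [b, 2*dim, h, w]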
def forward(self, x: Tensor, w: Tensor = None, dtype=None, memory_format=None) -> Tensor:
    """
    Dims:
        @arg x is [batch_size, in_channels, img_size, img_size]
        @arg w is [batch_size, w_dim]
        @return out is [batch_size, in_channels + fourier_dim + cips_dim, img_size, img_size]
    """
    assert memory_format is torch.contiguous_format
    if self.cfg.fallback:
        return x

    batch_size, in_channels, img_size = x.shape[:3]
    out = x

    # Fast path: everything after `x` is cached from a previous call.
    if self.use_full_cache and self._full_cache is not None and self._full_cache.device == x.device \
            and self._full_cache.shape == (batch_size, self.get_total_dim(), img_size, img_size):
        return torch.cat([x, self._full_cache], dim=1)

    if self._fourier_embs_cache is not None and self._fourier_embs_cache.device == x.device \
            and self._fourier_embs_cache.shape == (batch_size, self.get_total_dim() - self.const_emb_size, img_size, img_size):
        out = torch.cat([out, self._fourier_embs_cache], dim=1)
    else:
        raw_embs = []
        raw_coords = generate_coords(batch_size, img_size, x.device)  # [batch_size, coord_dim, img_size, img_size]

        if self.use_raw_coords:
            out = torch.cat([out, raw_coords], dim=1)

        if self.log_emb_size > 0:
            log_bases = self.log_basis.unsqueeze(0).repeat(batch_size, 1, 1)  # [batch_size, log_emb_size, 2]
            raw_log_embs = torch.einsum('bdc,bcxy->bdxy', log_bases, raw_coords)  # [batch_size, log_emb_size, img_size, img_size]
            raw_embs.append(raw_log_embs)

        if self.random_emb_size > 0:
            random_bases = self.random_basis.unsqueeze(0).repeat(batch_size, 1, 1)  # [batch_size, random_emb_size, 2]
            raw_random_embs = torch.einsum('bdc,bcxy->bdxy', random_bases, raw_coords)  # [batch_size, random_emb_size, img_size, img_size]
            raw_embs.append(raw_random_embs)

        if self.shared_emb_size > 0:
            shared_bases = self.shared_basis.unsqueeze(0).repeat(batch_size, 1, 1)  # [batch_size, shared_emb_size, 2]
            raw_shared_embs = torch.einsum('bdc,bcxy->bdxy', shared_bases, raw_coords)  # [batch_size, shared_emb_size, img_size, img_size]
            raw_embs.append(raw_shared_embs)

        if self.predictable_emb_size > 0:
            misc.assert_shape(w, [batch_size, None])
            mod = self.affine(w)  # [batch_size, W_size + b_size]
            W = self.fourier_scale * mod[:, :self.W_size]  # [batch_size, W_size]
            W = W.view(batch_size, self.predictable_emb_size, self.cfg.coord_dim)  # [batch_size, predictable_emb_size, coord_dim]
            bias = mod[:, self.W_size:].view(batch_size, self.predictable_emb_size, 1, 1)  # [batch_size, predictable_emb_size, 1, 1]
            raw_predictable_embs = (torch.einsum('bdc,bcxy->bdxy', W, raw_coords) + bias)  # [batch_size, predictable_emb_size, img_size, img_size]
            raw_embs.append(raw_predictable_embs)

        if len(raw_embs) > 0:
            raw_embs = torch.cat(raw_embs, dim=1)  # [batch_size, log_emb_size + random_emb_size + predictable_emb_size, img_size, img_size]
            raw_embs = raw_embs.contiguous()  # [batch_size, -1, img_size, img_size]
            out = torch.cat([out, raw_embs.sin().to(dtype=dtype, memory_format=memory_format)], dim=1)  # [batch_size, -1, img_size, img_size]
            if self.use_cosine:
                out = torch.cat([out, raw_embs.cos().to(dtype=dtype, memory_format=memory_format)], dim=1)  # [batch_size, -1, img_size, img_size]

        if self.predictable_emb_size == 0 and self.shared_emb_size == 0 and out.shape[1] > x.shape[1]:
            self._fourier_embs_cache = out[:, x.shape[1]:].detach()

    if self.const_emb_size > 0:
        const_embs = self.const_embs.repeat([batch_size, 1, 1, 1])
        const_embs = const_embs.to(dtype=dtype, memory_format=memory_format)
        out = torch.cat([out, const_embs], dim=1)  # [batch_size, total_dim, img_size, img_size]

    if self.use_full_cache and self.predictable_emb_size == 0 and self.shared_emb_size == 0 and out.shape[1] > x.shape[1]:
        self._full_cache = out[:, x.shape[1]:].detach()

    return out
def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
    misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
    w_iter = iter(ws.unbind(dim=1))
    dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
    memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
    if fused_modconv is None:
        with misc.suppress_tracer_warnings():  # this value will be treated as a constant
            fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)

    # Input.
    if self.in_channels == 0:
        x = self.const.to(dtype=dtype, memory_format=memory_format)
        x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
    else:
        misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
        x = x.to(dtype=dtype, memory_format=memory_format)

    # Main layers.
    if self.in_channels == 0:
        x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
    elif self.architecture == 'resnet':
        y = self.skip(x, gain=np.sqrt(0.5))
        x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
        x = y.add_(x)
    else:
        x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

    # ToRGB (+ segmentation).
    if img is not None:
        misc.assert_shape(img, [None, self.img_channels + self.segmentation_channels, self.resolution // 2, self.resolution // 2])
        img = upfirdn2d.upsample2d(img, self.resample_filter)
    if self.is_last or self.architecture == 'skip':
        w_temp = next(w_iter)
        rgb = self.torgb(x, w_temp, fused_modconv=fused_modconv)
        rgb = rgb.to(dtype=torch.float32, memory_format=torch.contiguous_format)
        segmentation = self.tosegmentation(x, w_temp, fused_modconv=fused_modconv)
        newImg = torch.cat((rgb, segmentation), dim=1)
        img = img.add_(newImg) if img is not None else newImg
    if self.is_last:
        # Snap the segmentation logits to a one-hot map: subtract the per-pixel
        # max so only the argmax channel stays >= 0, then round to {0, 1}.
        originalSegmentation = img[:, 3:]
        maxs = torch.max(originalSegmentation, dim=1)[0].unsqueeze(1)
        afterSubtraction = originalSegmentation - maxs + self.eps
        finalArray = torch.round(torch.max(self.zeroTensor, afterSubtraction) / self.eps)
        img[:, 3:] = finalArray

    assert x.dtype == dtype
    assert img is None or img.dtype == torch.float32
    return x, img
def modulated_conv2d(
    x,                     # input, shape=[batch_size, in_channels, in_height, in_width]
    weight,                # weights, shape=[out_channels, in_channels, kernel_height, kernel_width]
    styles,                # modulation coefficients, shape=[batch_size, in_channels]
    noise=None,            # optional noise to add to the output activations
    up=1,                  # upsampling factor
    down=1,                # downsampling factor
    padding=0,             # padding with respect to the upsampled image
    resample_filter=None,  # low-pass filter for resampling, prepared via upfirdn2d.setup_filter()
    demodulate=True,       # apply weight demodulation?
    flip_weight=True,      # False = convolution, True = correlation
    fused_modconv=True,    # perform modulation as a single fused operation?
):
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape
    misc.assert_shape(weight, [out_channels, in_channels, kh, kw])
    misc.assert_shape(x, [batch_size, in_channels, None, None])
    misc.assert_shape(styles, [batch_size, in_channels])

    # Normalize inputs to avoid FP16 overflow.
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1, 2, 3], keepdim=True))
        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True)

    # Calculate per-sample weights and demodulation coefficients.
    w = None
    demod_coeff = None
    if demodulate or fused_modconv:
        w = weight.unsqueeze(0)
        w = w * styles.reshape(batch_size, 1, -1, 1, 1)  # modulation is multiplicative
    if demodulate:
        demod_coeff = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()
    if demodulate and fused_modconv:
        w = w * demod_coeff.reshape(batch_size, -1, 1, 1, 1)

    # Execute modulation by scaling the activations.
    if not fused_modconv:
        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
        x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter,
                                            up=up, down=down, padding=padding, flip_weight=flip_weight)
        if demodulate and noise is not None:
            x = fma.fma(x, demod_coeff.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
        elif demodulate:
            x = x * demod_coeff.to(x.dtype).reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x

    # Execute as one fused op using grouped convolution.
    with misc.suppress_tracer_warnings():
        batch_size = int(batch_size)
    misc.assert_shape(x, [batch_size, in_channels, None, None])
    x = x.reshape(1, -1, *x.shape[2:])  # fold the batch into the channel dim
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter,
                                        up=up, down=down, padding=padding,
                                        groups=batch_size, flip_weight=flip_weight)
    x = x.reshape(batch_size, -1, *x.shape[2:])
    if noise is not None:
        x = x.add_(noise)
    return x