def forward(self, x, latmask, w, noise_mode='random', fused_modconv=True, gain=1):
    # def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
    assert noise_mode in ['random', 'const', 'none']
    in_resolution = self.resolution // self.up
    # misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution])
    styles = self.affine(w)

    noise = None
    if self.use_noise and noise_mode == 'random':
        # !!! custom noise size
        sz = self.size if self.up == 2 and self.size is not None else x.shape[2:]
        noise = torch.randn([x.shape[0], 1, *sz], device=x.device) * self.noise_strength
        # noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
    if self.use_noise and noise_mode == 'const':
        noise = self.noise_const * self.noise_strength
        # !!! custom noise size
        noise_size = self.size if self.up == 2 and self.size is not None and self.resolution > 4 else x.shape[2:]
        noise = fix_size(noise.unsqueeze(0).unsqueeze(0), noise_size, scale_type=self.scale_type)[0][0]
        # print(x.shape, noise.shape, self.size, self.up)

    flip_weight = (self.up == 1)  # slightly faster
    x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
                         latmask=latmask, countHW=self.countHW, splitfine=self.splitfine,
                         size=self.size, scale_type=self.scale_type,  # !!! custom
                         padding=self.padding, resample_filter=self.resample_filter,
                         flip_weight=flip_weight, fused_modconv=fused_modconv)

    act_gain = self.act_gain * gain
    act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
    x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
    return x
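# Hedged sketch (not this repo's code): the custom `size` path above relies on fix_size(), whose
# implementation is not shown in this excerpt. A minimal stand-in, assuming 'fit' means bilinear
# resizing and other modes mean center pad/crop; the name, modes, and behavior here are illustrative.
import torch
import torch.nn.functional as F

def fix_size_sketch(x, size, scale_type='fit'):
    # Bring an NCHW tensor to an arbitrary (H, W) target.
    if size is None or tuple(x.shape[2:]) == tuple(size):
        return x
    if scale_type == 'fit':
        return F.interpolate(x, size=tuple(size), mode='bilinear', align_corners=False)
    # Otherwise center-pad (or center-crop) without rescaling.
    dh, dw = size[0] - x.shape[2], size[1] - x.shape[3]
    pad = [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2]  # left, right, top, bottom
    x = F.pad(x, [max(p, 0) for p in pad])
    h0, w0 = max(-pad[2], 0), max(-pad[0], 0)
    return x[:, :, h0:h0 + size[0], w0:w0 + size[1]]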
def forward(self, x, img, ws, latmask, dconst, force_fp32=False, fused_modconv=None, **layer_kwargs):
    # def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
    misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
    w_iter = iter(ws.unbind(dim=1))
    dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
    memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
    if fused_modconv is None:
        with misc.suppress_tracer_warnings():  # this value will be treated as a constant
            fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)

    # Input.
    if self.in_channels == 0:
        x = self.const.to(dtype=dtype, memory_format=memory_format)
        x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        # !!! custom const size
        if 'side' in self.scale_type and 'symm' in self.scale_type:  # looks better
            const_size = self.init_res if self.size is None else self.size
            x = fix_size(x, const_size, self.scale_type)
        # distortion technique from Aydao
        x += dconst
    else:
        # misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
        x = x.to(dtype=dtype, memory_format=memory_format)

    # Main layers.
    if self.in_channels == 0:
        # !!! custom latmask
        x = self.conv1(x, None, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        # x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
    elif self.architecture == 'resnet':
        y = self.skip(x, gain=np.sqrt(0.5))
        # !!! custom latmask
        x = self.conv0(x, latmask, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        x = self.conv1(x, None, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
        # x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        # x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
        x = y.add_(x)
    else:
        # !!! custom latmask
        x = self.conv0(x, latmask, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        x = self.conv1(x, None, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        # x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        # x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

    # ToRGB.
    if img is not None:
        # !!! custom img size
        # misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
        img = upfirdn2d.upsample2d(img, self.resample_filter)
        img = fix_size(img, self.size, scale_type=self.scale_type)
    if self.is_last or self.architecture == 'skip':
        y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
        y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
        img = img.add_(y) if img is not None else y

    assert x.dtype == dtype
    assert img is None or img.dtype == torch.float32
    return x, img
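# Hedged sketch (illustrative, not the repo's code): dconst is the "Aydao" distortion added to the
# tiled const input above, so it must broadcast against [batch, const_channels, init_H, init_W].
# Zeros reproduce the undistorted output; a small random field warps the whole image. The shapes
# and the call at the end are assumed example values, not actual API of this fork.
import torch

batch, const_channels, init_res = 2, 512, (4, 4)
dconst = torch.zeros([batch, const_channels, *init_res])           # no distortion
# dconst = 0.1 * torch.randn([batch, const_channels, *init_res])   # mild distortion
# x, img = synthesis_block(x, img, ws, latmask, dconst)            # hypothetical call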
def modulated_conv2d(
    x,                          # Input tensor of shape [batch_size, in_channels, in_height, in_width].
    weight,                     # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
    styles,                     # Modulation coefficients of shape [batch_size, in_channels].
    # !!! custom
    latmask,                    # Mask for split-frame latents blending.
    countHW         = [1, 1],   # Frame split count by height, width.
    splitfine       = 0.,       # Frame split edge fineness (float from 0+).
    size            = None,     # Custom size.
    scale_type      = None,     # Scaling way: fit, centr, side, pad, padside.
    noise           = None,     # Optional noise tensor to add to the output activations.
    up              = 1,        # Integer upsampling factor.
    down            = 1,        # Integer downsampling factor.
    padding         = 0,        # Padding with respect to the upsampled image.
    resample_filter = None,     # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
    demodulate      = True,     # Apply weight demodulation?
    flip_weight     = True,     # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
    fused_modconv   = True,     # Perform modulation, convolution, and demodulation as a single fused operation?
):
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape
    misc.assert_shape(weight, [out_channels, in_channels, kh, kw])  # [OIkk]
    misc.assert_shape(x, [batch_size, in_channels, None, None])     # [NIHW]
    misc.assert_shape(styles, [batch_size, in_channels])            # [NI]

    # Pre-normalize inputs to avoid FP16 overflow.
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1, 2, 3], keepdim=True))  # max_Ikk
        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True)  # max_I

    # Calculate per-sample weights and demodulation coefficients.
    w = None
    dcoefs = None
    if demodulate or fused_modconv:
        w = weight.unsqueeze(0)  # [NOIkk]
        w = w * styles.reshape(batch_size, 1, -1, 1, 1)  # [NOIkk]
    if demodulate:
        dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()  # [NO]
    if demodulate and fused_modconv:
        w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1)  # [NOIkk]

    # Execute by scaling the activations before and after the convolution.
    if not fused_modconv:
        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
        x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down,
                                            padding=padding, flip_weight=flip_weight)
        # !!! custom size & multi latent blending
        if size is not None and up == 2:
            x = fix_size(x, size, scale_type)
            x = multimask(x, size, latmask, countHW, splitfine)
        if demodulate and noise is not None:
            x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
        elif demodulate:
            x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x

    # Execute as one fused op using grouped convolution.
    with misc.suppress_tracer_warnings():  # this value will be treated as a constant
        batch_size = int(batch_size)
    misc.assert_shape(x, [batch_size, in_channels, None, None])
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down,
                                        padding=padding, groups=batch_size, flip_weight=flip_weight)
    x = x.reshape(batch_size, -1, *x.shape[2:])
    # !!! custom size & multi latent blending
    if size is not None and up == 2:
        x = fix_size(x, size, scale_type)
        x = multimask(x, size, latmask, countHW, splitfine)
    if noise is not None:
        x = x.add_(noise)
    return x
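# Hedged sketch (reference only, not part of the repo): the fused grouped-convolution path above is
# equivalent to modulating/demodulating the weights per sample and convolving each sample separately.
# A plain loop version for the demodulate=True, up=down=1 case, without noise or the custom
# size/blending logic, useful for sanity-checking the grouped-conv trick.
import torch
import torch.nn.functional as F

def modulated_conv2d_reference(x, weight, styles, demodulate=True, padding=1):
    batch_size = x.shape[0]
    w = weight.unsqueeze(0) * styles.reshape(batch_size, 1, -1, 1, 1)   # [NOIkk]
    if demodulate:
        dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()         # [NO]
        w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1)
    out = [F.conv2d(x[i:i + 1], w[i], padding=padding) for i in range(batch_size)]
    return torch.cat(out, dim=0)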
def G_synthesis_stylegan2(
    dlatents_in,                        # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size].
    latmask,                            # Mask for split-frame latents blending.
    dconst,                             # Initial (const) layer displacement.
    latmask_res     = [1, 1],           # Resolution of external mask for blending.
    countW          = 1,                # Frame split count by width.
    countH          = 1,                # Frame split count by height.
    splitfine       = 0.,               # Frame split edge sharpness (float from 0).
    size            = None,             # Output size.
    scale_type      = None,             # Scaling way: fit, centr, side, pad, padside.
    init_res        = [4, 4],           # Initial (minimum) resolution for progressive training.
    dlatent_size    = 512,              # Disentangled latent (W) dimensionality.
    num_channels    = 3,                # Number of output color channels.
    resolution      = 1024,             # Base model resolution (corresponding to the layer count).
    fmap_base       = 16 << 10,         # Overall multiplier for the number of feature maps.
    fmap_decay      = 1.0,              # log2 feature map reduction when doubling the resolution.
    fmap_min        = 1,                # Minimum number of feature maps in any layer.
    fmap_max        = 512,              # Maximum number of feature maps in any layer.
    randomize_noise = True,             # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
    architecture    = 'skip',           # Architecture: 'orig', 'skip', 'resnet'.
    nonlinearity    = 'lrelu',          # Activation function: 'relu', 'lrelu', etc.
    dtype           = 'float32',        # Data type to use for activations and outputs.
    resample_kernel = [1, 3, 3, 1],     # Low-pass filter to apply when resampling activations. None = no filtering.
    fused_modconv   = True,             # Implement modulated_conv2d_layer() as a single fused op?
    verbose         = False,            #
    impl            = 'cuda',           # Custom ops implementation - cuda (original) or ref (no compiling).
    **_kwargs):                         # Ignore unrecognized keyword args.

    res_log2 = int(np.log2(resolution))
    assert resolution == 2 ** res_log2 and resolution >= 4

    # Calculate intermediate layer sizes for arbitrary output resolution.
    custom_res = (resolution * init_res[0] // 4, resolution * init_res[1] // 4)
    if size is None:
        size = custom_res
    if init_res != [4, 4] and verbose:
        print(' .. init res', init_res, size)
    keep_first_layers = 2 if scale_type == 'fit' else None
    hws = hw_scales(size, custom_res, res_log2 - 2, keep_first_layers, verbose)
    if verbose:
        print(hws, '..', custom_res, res_log2 - 1)

    # Multi latent blending.
    latmask.set_shape([None, *latmask_res])
    dconst.set_shape([None, dlatent_size, *init_res])
    splitfine = tf.cast(splitfine, tf.float32)

    def nf(stage):
        return np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_min, fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity
    num_layers = res_log2 * 2 - 2
    images_out = None

    # Primary inputs.
    dlatents_in.set_shape([None, num_layers, dlatent_size])
    dlatents_in = tf.cast(dlatents_in, dtype)

    # Noise inputs.
    noise_inputs = []
    for layer_idx in range(num_layers - 1):
        res = (layer_idx + 5) // 2
        shape = [1, 1, 2 ** (res - 2) * init_res[0], 2 ** (res - 2) * init_res[1]]
        noise_inputs.append(tf.get_variable('noise%d' % layer_idx, shape=shape,
                                            initializer=tf.initializers.random_normal(), trainable=False))

    # Single convolution layer with all the bells and whistles.
    def layer(x, layer_idx, size, fmaps, kernel, up=False):
        x = modulated_conv2d_layer(x, dlatents_in[:, layer_idx], fmaps=fmaps, kernel=kernel, up=up,
                                   resample_kernel=resample_kernel, fused_modconv=fused_modconv, impl=impl)
        if size is not None and up is True:
            x = fix_size(x, size, scale_type)
            # multi latent blending
            x = multimask(x, size, latmask, countH, countW, splitfine)
        if randomize_noise:
            noise = tf.random_normal([tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype)
        else:
            noise = tf.cast(noise_inputs[layer_idx], x.dtype)
            noise = fix_size(noise, (x.shape[2], x.shape[3]), scale_type=scale_type)
        noise_strength = tf.get_variable('noise_strength', shape=[], initializer=tf.initializers.zeros())
        x += noise * tf.cast(noise_strength, x.dtype)
        return apply_bias_act(x, act=act, impl=impl)

    # Building blocks for main layers.
    def block(x, res, size):  # res = 3..res_log2
        t = x
        with tf.variable_scope('Conv0_up'):
            x = layer(x, layer_idx=res * 2 - 5, size=size, fmaps=nf(res - 1), kernel=3, up=True)
        with tf.variable_scope('Conv1'):
            x = layer(x, layer_idx=res * 2 - 4, size=size, fmaps=nf(res - 1), kernel=3)
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t, fmaps=nf(res - 1), kernel=1, up=True, resample_kernel=resample_kernel, impl=impl)
                if size is not None:
                    t = fix_size(t, (x.shape[2], x.shape[3]), scale_type=scale_type)
                x = (x + t) * (1 / np.sqrt(2))
        return x

    def upsample(y):
        with tf.variable_scope('Upsample'):
            return upsample_2d(y, k=resample_kernel, impl=impl)

    def torgb(x, y, res):  # res = 2..res_log2
        with tf.variable_scope('ToRGB'):
            t = apply_bias_act(modulated_conv2d_layer(x, dlatents_in[:, res * 2 - 3], fmaps=num_channels, kernel=1,
                                                      demodulate=False, fused_modconv=fused_modconv, impl=impl),
                               impl=impl)
            return t if y is None else y + t

    # Early layers.
    y = None
    with tf.variable_scope('4x4'):
        with tf.variable_scope('Const'):
            x = tf.get_variable('const', shape=[1, nf(1), *init_res], initializer=tf.initializers.random_normal())
            x = tf.tile(tf.cast(x, dtype), [tf.shape(dlatents_in)[0], 1, 1, 1])
            # distortion technique from Aydao
            x += dconst
        with tf.variable_scope('Conv'):
            x = layer(x, layer_idx=0, size=None, fmaps=nf(1), kernel=3)
        if architecture == 'skip':
            y = torgb(x, y, 2)

    # Main layers.
    for res in range(3, res_log2 + 1):
        with tf.variable_scope('%dx%d' % (2 ** res, 2 ** res)):
            x = block(x, res, hws[res - 2])
            if architecture == 'skip':
                y = upsample(y)
                if size is not None:
                    y = fix_size(y, hws[res - 2], scale_type=scale_type)
            if architecture == 'skip' or res == res_log2:
                y = torgb(x, y, res)
    images_out = y

    assert images_out.dtype == tf.as_dtype(dtype)
    return tf.identity(images_out, name='images_out')
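# Hedged sketch (hypothetical, not the repo's code): hw_scales() is not shown in this excerpt. One
# plausible implementation returns one (H, W) per upsampling block, ending exactly at `size` and
# halving toward the earliest blocks, with the first `keep_first_layers` blocks left at their
# proportional base sizes (as the 'fit' mode above suggests). Names and behavior are assumptions.
import numpy as np

def hw_scales_sketch(size, base, n_layers, keep_first_layers=None, verbose=False):
    hws = [tuple(size)]
    for _ in range(n_layers - 1):
        prev = hws[0]
        hws.insert(0, (int(np.ceil(prev[0] / 2)), int(np.ceil(prev[1] / 2))))
    if keep_first_layers is not None:
        for i in range(min(keep_first_layers, n_layers)):
            scale = 2 ** (n_layers - 1 - i)
            hws[i] = (base[0] // scale, base[1] // scale)
    if verbose:
        print(hws)
    return hws

# Example: hw_scales_sketch((720, 1280), (1024, 1024), 8) -> sizes from (6, 10) up to (720, 1280).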