def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
    """Run one synthesis block and accumulate its RGB output.

    Args:
        x: Feature map from the previous block. When ``self.in_channels == 0``
            it is replaced by the output of ``self.input``.
        img: Running RGB image from the previous block, or ``None``.
        ws: Latents of shape ``[batch, num_conv + num_torgb, w_dim]``; one latent
            is consumed per layer, in order.
        force_fp32: Disable FP16 compute and channels-last layout for this call.
        fused_modconv: Use fused modulated convolution. ``None`` decides
            automatically: only outside training, and only for FP32 or a batch of 1.
        **layer_kwargs: Forwarded to every convolution layer.

    Returns:
        ``(x, img)``: features in the block's compute dtype and the float32 image
        (``None`` if this block emits no RGB and none was passed in).
    """
    misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
    latents = iter(ws.unbind(dim=1))  # one latent per layer, consumed in order

    use_half = self.use_fp16 and not force_fp32
    dtype = torch.float16 if use_half else torch.float32
    use_nhwc = self.channels_last and not force_fp32
    memory_format = torch.channels_last if use_nhwc else torch.contiguous_format

    if fused_modconv is None:
        with misc.suppress_tracer_warnings():  # this value will be treated as a constant
            fused_modconv = (not self.training) and (
                dtype == torch.float32
                or (isinstance(x, Tensor) and int(x.shape[0]) == 1)
            )

    if self.in_channels == 0:
        # Input layer: the first latent drives both self.input and conv1.
        first_w = next(latents)
        x = self.input(ws.shape[0], first_w, device=ws.device, dtype=dtype, memory_format=memory_format)
        x = self.conv1(x, first_w, fused_modconv=fused_modconv, **layer_kwargs)
    else:
        misc.assert_shape(x, [None, self.in_channels, self.input_resolution, self.input_resolution])
        x = x.to(dtype=dtype, memory_format=memory_format)
        x = maybe_upsample(x, self.cfg.upsampling_mode, self.up)

        if self.architecture == 'resnet':
            # Residual branch: both paths are scaled by sqrt(0.5) before summing.
            gain = np.sqrt(0.5)
            shortcut = self.skip(x, gain=gain)
            x = self.conv0(x, next(latents), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(latents), fused_modconv=fused_modconv, gain=gain, **layer_kwargs)
            x = shortcut.add_(x)
        else:
            conv0_w = next(latents)
            # The optional coordinate fuser shares conv0's latent.
            if self.coord_fuser is not None:
                x = self.coord_fuser(x, conv0_w, dtype=dtype, memory_format=memory_format)
            x = self.conv0(x, conv0_w, fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(latents), fused_modconv=fused_modconv, **layer_kwargs)
            # NOTE(review): original indentation was lost; the extra convs are assumed
            # to belong to this plain-architecture path -- confirm. zip() stops as soon
            # as extra_convs is exhausted, so it consumes exactly one latent per conv.
            if self.extra_convs is not None:
                for extra_conv, extra_w in zip(self.extra_convs, latents):
                    x = extra_conv(x, extra_w, fused_modconv=fused_modconv, **layer_kwargs)

    # Bring the running image up to this block's resolution.
    if img is not None:
        misc.assert_shape(img, [None, self.img_channels, self.input_resolution, self.input_resolution])
        if self.up == 2:
            if self.cfg.upsampling_mode is None:
                img = upfirdn2d.upsample2d(img, self.resample_filter)
            else:
                img = maybe_upsample(img, self.cfg.upsampling_mode, 2)

    # Add this block's RGB contribution (always float32, contiguous).
    if self.is_last or self.architecture == 'skip':
        rgb = self.torgb(x, next(latents), fused_modconv=fused_modconv)
        rgb = rgb.to(dtype=torch.float32, memory_format=torch.contiguous_format)
        if img is None:
            img = rgb
        else:
            img = img.add_(rgb)

    assert x.dtype == dtype
    assert img is None or img.dtype == torch.float32
    return x, img
def modulated_conv2d(
    x,                      # Input activations, shape [batch_size, in_channels, in_height, in_width].
    weight,                 # Convolution weights, shape [out_channels, in_channels, kernel_height, kernel_width].
    styles,                 # Per-sample modulation coefficients, shape [batch_size, in_channels].
    noise=None,             # Optional noise tensor added to the output activations.
    up=1,                   # Integer upsampling factor.
    down=1,                 # Integer downsampling factor.
    padding=0,              # Padding with respect to the upsampled image.
    resample_filter=None,   # Low-pass filter for resampling; prepare it with upfirdn2d.setup_filter().
    demodulate=True,        # Normalize each (sample, output channel) of the modulated weights?
    flip_weight=True,       # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
    fused_modconv=True,     # Modulate, convolve and demodulate as one grouped convolution?
):
    """StyleGAN2-style modulated convolution.

    Scales the input channels of ``weight`` per sample by ``styles``, optionally
    demodulates, convolves through ``conv2d_resample``, and optionally adds noise.
    Two execution strategies exist: scaling the activations around a convolution
    with the shared weights (``fused_modconv=False``), or a single grouped
    convolution with per-sample weights (``fused_modconv=True``).
    """
    n = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape
    misc.assert_shape(weight, [out_channels, in_channels, kh, kw])  # [OIkk]
    misc.assert_shape(x, [n, in_channels, None, None])  # [NIHW]
    misc.assert_shape(styles, [n, in_channels])  # [NI]

    # In FP16, pre-scale weights and styles by their max magnitude to avoid overflow.
    # The resulting per-(sample, output channel) scale is cancelled by demodulation.
    if demodulate and x.dtype == torch.float16:
        inv_fan_in = 1 / np.sqrt(in_channels * kh * kw)
        weight = weight * (inv_fan_in / weight.norm(float('inf'), dim=[1, 2, 3], keepdim=True))  # max_Ikk
        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True)  # max_I

    # Per-sample weights are needed both for demodulation and for the fused path.
    modulated = None
    if demodulate or fused_modconv:
        modulated = weight.unsqueeze(0) * styles.reshape(n, 1, -1, 1, 1)  # [NOIkk]
    demod = None
    if demodulate:
        demod = (modulated.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()  # [NO]

    if not fused_modconv:
        # Modulate the activations, convolve with the shared weights, then demodulate.
        x = x * styles.to(x.dtype).reshape(n, -1, 1, 1)
        x = conv2d_resample.conv2d_resample(
            x=x, w=weight.to(x.dtype), f=resample_filter,
            up=up, down=down, padding=padding, flip_weight=flip_weight,
        )
        if not demodulate:
            return x if noise is None else x.add_(noise.to(x.dtype))
        out_scale = demod.to(x.dtype).reshape(n, -1, 1, 1)
        if noise is None:
            return x * out_scale
        return fma.fma(x, out_scale, noise.to(x.dtype))

    # Fused path: fold demodulation into the per-sample weights.
    if demodulate:
        modulated = modulated * demod.reshape(n, -1, 1, 1, 1)  # [NOIkk]

    with misc.suppress_tracer_warnings():  # this value will be treated as a constant
        n = int(n)
    misc.assert_shape(x, [n, in_channels, None, None])
    # Fold the batch into channels so that each of the n groups sees one sample.
    grouped_x = x.reshape(1, -1, *x.shape[2:])
    grouped_w = modulated.reshape(-1, in_channels, kh, kw)
    out = conv2d_resample.conv2d_resample(
        x=grouped_x, w=grouped_w.to(grouped_x.dtype), f=resample_filter,
        up=up, down=down, padding=padding, groups=n, flip_weight=flip_weight,
    )
    out = out.reshape(n, -1, *out.shape[2:])
    if noise is not None:
        out = out.add_(noise)
    return out
def forward(self, x):
    """Append minibatch standard-deviation features to ``x``.

    The batch is split into groups of ``G`` samples (``min(group_size, N)``, or
    the whole batch when ``group_size`` is None). The channels are split into
    ``self.num_channels`` groups. For each sample group and channel group, the
    per-position standard deviation over the group members is averaged over
    channels and pixels. The result is appended to ``x`` as
    ``self.num_channels`` constant feature maps.

    Args:
        x: Tensor of shape ``[N, C, H, W]``. The reshape requires ``N`` to be
            divisible by ``G`` and ``C`` by ``self.num_channels``.

    Returns:
        Tensor of shape ``[N, C + num_channels, H, W]``.
    """
    batch, channels, height, width = x.shape
    with misc.suppress_tracer_warnings():  # as_tensor results are registered as constants
        if self.group_size is None:
            group = batch
        else:
            group = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(batch))
    num_feat = self.num_channels
    per_feat = channels // num_feat

    grouped = x.reshape(group, -1, num_feat, per_feat, height, width)  # [GnFcHW]
    centered = grouped - grouped.mean(dim=0)  # subtract the group mean
    variance = centered.square().mean(dim=0)  # [nFcHW] variance over the group
    stddev = (variance + 1e-8).sqrt()  # [nFcHW]
    stat = stddev.mean(dim=[2, 3, 4]).reshape(-1, num_feat, 1, 1)  # [nF11]
    stat = stat.repeat(group, 1, height, width)  # [NFHW] broadcast back over group and pixels
    return torch.cat([x, stat], dim=1)
def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
    """Run one synthesis block whose spatial size is derived from ``self.init_res``.

    Args:
        x: Feature map from the previous block. When ``self.in_channels == 0``
            it is replaced by the learned constant ``self.const`` repeated over
            the batch.
        img: Running RGB image from the previous block, or ``None``.
        ws: Latents of shape ``[batch, num_conv + num_torgb, w_dim]``; one latent
            is consumed per layer, in order.
        force_fp32: Disable FP16 compute and channels-last layout for this call.
        fused_modconv: Use fused modulated convolution. ``None`` decides
            automatically: only outside training, and only for FP32 or a batch of 1.
        **layer_kwargs: Forwarded to every convolution layer.

    Returns:
        ``(x, img)``: features in the block's compute dtype and the float32 image.
    """
    misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
    latents = iter(ws.unbind(dim=1))
    half = self.use_fp16 and not force_fp32
    dtype = torch.float16 if half else torch.float32
    memory_format = torch.channels_last if (self.channels_last and not force_fp32) else torch.contiguous_format

    if fused_modconv is None:
        with misc.suppress_tracer_warnings():  # this value will be treated as a constant
            if self.training:
                fused_modconv = False
            else:
                fused_modconv = dtype == torch.float32 or int(x.shape[0]) == 1

    if self.in_channels == 0:
        # Learned constant input, repeated over the batch, followed by conv1 only.
        x = self.const.to(dtype=dtype, memory_format=memory_format)
        x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        x = self.conv1(x, next(latents), fused_modconv=fused_modconv, **layer_kwargs)
    else:
        # Custom: the incoming height/width come from self.init_res instead of the
        # square resolution // 2 (the two agree when init_res is (4, 4)).
        in_h = self.resolution * self.init_res[0] // 8
        in_w = self.resolution * self.init_res[1] // 8
        misc.assert_shape(x, [None, self.in_channels, in_h, in_w])
        x = x.to(dtype=dtype, memory_format=memory_format)
        if self.architecture == 'resnet':
            shortcut = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, next(latents), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(latents), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = shortcut.add_(x)
        else:
            x = self.conv0(x, next(latents), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(latents), fused_modconv=fused_modconv, **layer_kwargs)

    if img is not None:
        # Custom: same init_res-derived size check for the running image.
        prev_h = self.resolution * self.init_res[0] // 8
        prev_w = self.resolution * self.init_res[1] // 8
        misc.assert_shape(img, [None, self.img_channels, prev_h, prev_w])
        img = upfirdn2d.upsample2d(img, self.resample_filter)

    if self.is_last or self.architecture == 'skip':
        rgb = self.torgb(x, next(latents), fused_modconv=fused_modconv).to(
            dtype=torch.float32, memory_format=torch.contiguous_format)
        img = rgb if img is None else img.add_(rgb)

    assert x.dtype == dtype
    assert img is None or img.dtype == torch.float32
    return x, img
def modulated_conv2d(
    x,  # Input, shape [batch_size, in_channels, in_height, in_width].
    weight,  # Weights, shape [out_channels, in_channels, kernel_height, kernel_width].
    styles,  # Modulation coefficients, shape [batch_size, in_channels].
    noise=None,  # Optional noise added to the output activations.
    up=1,  # Integer upsampling factor.
    down=1,  # Integer downsampling factor.
    padding=0,  # Padding with respect to the upsampled image.
    resample_filter=None,  # Low-pass filter used when resampling.
    demodulate=True,  # Apply weight demodulation?
    flip_weight=True,  # False = convolution, True = correlation.
    fused_modconv=True,  # Modulate, convolve and demodulate as one grouped convolution?
):
    """Modulated convolution with optional demodulation and noise.

    Each input channel of ``weight`` is scaled per sample by ``styles``. With
    ``demodulate``, every (sample, output channel) slice of the modulated
    weights is normalized to unit L2 norm. The convolution is executed via
    ``conv2d_resample``, either around the shared weights
    (``fused_modconv=False``) or as one grouped convolution with per-sample
    weights (``fused_modconv=True``).

    Returns:
        Tensor of shape ``[batch_size, out_channels, out_height, out_width]``.
    """
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape
    misc.assert_shape(weight, [out_channels, in_channels, kh, kw])  # [OIkk]
    misc.assert_shape(x, [batch_size, in_channels, None, None])  # [NIHW]
    misc.assert_shape(styles, [batch_size, in_channels])  # [NI]

    # Pre-normalize inputs to avoid FP16 overflow; the per-(sample, output
    # channel) scale this introduces is removed again by demodulation.
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(
            float('inf'), dim=[1, 2, 3], keepdim=True))  # max_Ikk
        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True)  # max_I

    # Calculate per-sample weights and demodulation coefficients.
    w = None
    demod_coeff = None
    if demodulate or fused_modconv:
        w = weight.unsqueeze(0)  # [1OIkk]
        # Modulation *scales* each input channel by its style; this must be a
        # product to agree with the non-fused path below, which multiplies x.
        w = w * styles.reshape(batch_size, 1, -1, 1, 1)  # [NOIkk]
    if demodulate:
        demod_coeff = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()  # [NO]
    if demodulate and fused_modconv:
        w = w * demod_coeff.reshape(batch_size, -1, 1, 1, 1)  # [NOIkk]

    # Non-fused: scale the activations before and after a shared convolution.
    if not fused_modconv:
        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
        x = conv2d_resample.conv2d_resample(
            x=x,
            w=weight.to(x.dtype),
            f=resample_filter,
            up=up,
            down=down,
            padding=padding,
            flip_weight=flip_weight,
        )
        if demodulate and noise is not None:
            x = fma.fma(x, demod_coeff.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
        elif demodulate:
            x = x * demod_coeff.to(x.dtype).reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x

    # Fused: one grouped convolution with batch_size groups.
    with misc.suppress_tracer_warnings():  # this value will be treated as a constant
        batch_size = int(batch_size)
    misc.assert_shape(x, [batch_size, in_channels, None, None])
    # Fold the batch into the channel dim -> [1, N*I, H, W], so that group n of
    # the grouped convolution sees exactly sample n's in_channels inputs.
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)  # [N*O, I, kh, kw]
    x = conv2d_resample.conv2d_resample(
        x=x,
        w=w.to(x.dtype),
        f=resample_filter,
        up=up,
        down=down,
        padding=padding,
        groups=batch_size,
        flip_weight=flip_weight,
    )
    x = x.reshape(batch_size, -1, *x.shape[2:])
    if noise is not None:
        x = x.add_(noise)
    return x
def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
    """Run one synthesis block that emits RGB plus segmentation channels.

    Args:
        x: Feature map from the previous block. When ``self.in_channels == 0``
            it is replaced by the learned constant ``self.const`` repeated over
            the batch.
        img: Running image with ``img_channels + segmentation_channels``
            channels from the previous block, or ``None``.
        ws: Latents of shape ``[batch, num_conv + num_torgb, w_dim]``; one latent
            is consumed per layer, in order.
        force_fp32: Disable FP16 compute and channels-last layout for this call.
        fused_modconv: Use fused modulated convolution. ``None`` decides
            automatically: only outside training, and only for FP32 or a batch of 1.
        **layer_kwargs: Forwarded to every convolution layer.

    Returns:
        ``(x, img)``: features in the block's compute dtype and the float32
        image. In the last block, the segmentation channels of ``img`` are
        replaced by a hard 0/1 mask.
    """
    misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
    w_iter = iter(ws.unbind(dim=1))
    dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
    memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
    if fused_modconv is None:
        with misc.suppress_tracer_warnings():  # this value will be treated as a constant
            # Guard against x being None (its value is unused on the constant-input
            # path) instead of raising AttributeError on x.shape.
            fused_modconv = (not self.training) and (
                dtype == torch.float32 or (x is not None and int(x.shape[0]) == 1))

    # Input.
    if self.in_channels == 0:
        x = self.const.to(dtype=dtype, memory_format=memory_format)
        x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
    else:
        misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
        x = x.to(dtype=dtype, memory_format=memory_format)

    # Main layers.
    if self.in_channels == 0:
        x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
    elif self.architecture == 'resnet':
        y = self.skip(x, gain=np.sqrt(0.5))
        x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
        x = y.add_(x)
    else:
        x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

    # ToRGB + ToSegmentation.
    if img is not None:
        misc.assert_shape(img, [
            None, self.img_channels + self.segmentation_channels,
            self.resolution // 2, self.resolution // 2,
        ])
        img = upfirdn2d.upsample2d(img, self.resample_filter)
    if self.is_last or self.architecture == 'skip':
        # Both output heads share a single latent.
        w_out = next(w_iter)
        rgb = self.torgb(x, w_out, fused_modconv=fused_modconv)
        rgb = rgb.to(dtype=torch.float32, memory_format=torch.contiguous_format)
        # Cast the segmentation head exactly like the RGB head, so the
        # concatenation is uniformly float32/contiguous. The head presumably
        # follows the block's compute dtype (FP16 in FP16 blocks), and the final
        # assert requires a float32 img.
        segmentation = self.tosegmentation(x, w_out, fused_modconv=fused_modconv)
        segmentation = segmentation.to(dtype=torch.float32, memory_format=torch.contiguous_format)
        new_img = torch.cat((rgb, segmentation), dim=1)
        img = img.add_(new_img) if img is not None else new_img

        if self.is_last:
            # Segmentation channels follow the first img_channels (the cat order
            # above; assumes torgb emits img_channels channels, as the incoming
            # shape check implies). Hard-threshold them per pixel: assuming
            # self.zeroTensor holds 0, channels within about eps/2 of the
            # per-pixel maximum become 1 and all others 0.
            # NOTE(review): round() has zero gradient, so nothing backpropagates
            # through these output channels.
            seg = img[:, self.img_channels:]
            seg_max = torch.max(seg, dim=1)[0].unsqueeze(1)
            shifted = seg - seg_max + self.eps
            hard_mask = torch.round(torch.max(self.zeroTensor, shifted) / self.eps)
            img[:, self.img_channels:] = hard_mask

    assert x.dtype == dtype
    assert img is None or img.dtype == torch.float32
    return x, img