def modulated_conv2d(
    x,
    weight,
    styles,
    noise=None,
    up=1,
    down=1,
    padding=0,
    resample_filter=None,
    demodulate=True,
    flip_weight=True,
    fused_modconv=True,
):
    """Convolve ``x`` with ``weight`` after modulating it per sample by ``styles``.

    Args:
        x: Input activations of shape [batch_size, in_channels, in_height, in_width].
        weight: Weights of shape [out_channels, in_channels, kernel_height, kernel_width].
        styles: Modulation coefficients of shape [batch_size, in_channels].
        noise: Optional tensor added to the output activations.
        up: Integer upsampling factor.
        down: Integer downsampling factor.
        padding: Padding with respect to the upsampled image.
        resample_filter: Low-pass filter applied when resampling activations.
            Must be prepared beforehand by calling upfirdn2d.setup_filter().
        demodulate: Apply weight demodulation?
        flip_weight: False = convolution, True = correlation (matches
            torch.nn.functional.conv2d).
        fused_modconv: If true, modulation, convolution and demodulation run as
            a single grouped convolution with per-sample weights. Otherwise the
            activations are scaled before and after one shared convolution.

    Returns:
        The output activations. The fused path reshapes them to have
        ``batch_size`` as the leading dimension. Presumably the shape is
        [batch_size, out_channels, out_height, out_width]; confirm against
        conv2d_resample.

    NOTE(review): this file contains another ``modulated_conv2d`` definition
    after this one, which shadows it at import time. Confirm which is intended.
    """
    n = x.shape[0]
    c_out, c_in, k_h, k_w = weight.shape
    misc.assert_shape(weight, [c_out, c_in, k_h, k_w])  # [OIkk]
    misc.assert_shape(x, [n, c_in, None, None])  # [NIHW]
    misc.assert_shape(styles, [n, c_in])  # [NI]

    # In FP16, divide weights and styles by their max magnitude first so the
    # products below stay in range. This is only done when demodulating,
    # because demodulation cancels these per-row scale factors again.
    if x.dtype == torch.float16 and demodulate:
        w_max = weight.norm(float('inf'), dim=[1, 2, 3], keepdim=True)  # max_Ikk
        weight = weight * (1 / np.sqrt(c_in * k_h * k_w) / w_max)
        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True)  # max_I

    # Per-sample weights [NOIkk] and demodulation coefficients [NO].
    mod_w = None
    demod = None
    if demodulate or fused_modconv:
        mod_w = weight.unsqueeze(0) * styles.reshape(n, 1, -1, 1, 1)
        if demodulate:
            demod = (mod_w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()
            if fused_modconv:
                mod_w = mod_w * demod.reshape(n, -1, 1, 1, 1)

    if fused_modconv:
        # Fold the batch into the channel axis. A grouped convolution with
        # one group per sample then applies each sample's own weights.
        with misc.suppress_tracer_warnings():  # this value will be treated as a constant
            n = int(n)
        misc.assert_shape(x, [n, c_in, None, None])
        x_flat = x.reshape(1, -1, *x.shape[2:])
        w_flat = mod_w.reshape(-1, c_in, k_h, k_w)
        out = conv2d_resample.conv2d_resample(
            x=x_flat,
            w=w_flat.to(x_flat.dtype),
            f=resample_filter,
            up=up,
            down=down,
            padding=padding,
            groups=n,
            flip_weight=flip_weight,
        )
        out = out.reshape(n, -1, *out.shape[2:])
        if noise is not None:
            out = out.add_(noise)  # in place on the conv output
        return out

    # Non-fused path: scale the activations by the styles, convolve with the
    # shared weights, then scale by the demodulation coefficients.
    x = x * styles.to(x.dtype).reshape(n, -1, 1, 1)
    x = conv2d_resample.conv2d_resample(
        x=x,
        w=weight.to(x.dtype),
        f=resample_filter,
        up=up,
        down=down,
        padding=padding,
        flip_weight=flip_weight,
    )
    if demodulate:
        scale = demod.to(x.dtype).reshape(n, -1, 1, 1)
        if noise is not None:
            # presumably fused multiply-add: x * scale + noise
            return fma.fma(x, scale, noise.to(x.dtype))
        return x * scale
    if noise is not None:
        return x.add_(noise.to(x.dtype))
    return x
def modulated_conv2d(
    x,  # input, shape=[batch_size, in_channels, in_height, in_width]
    weight,  # weights, shape=[out_channels, in_channels, kernel_height, kernel_width]
    styles,  # modulation coefficients, shape=[batch_size, in_channels]
    noise=None,  # optional noise added to the output activations
    up=1,  # integer upsampling factor
    down=1,  # integer downsampling factor
    padding=0,  # padding with respect to the upsampled image
    resample_filter=None,  # low-pass filter used when resampling activations
    demodulate=True,  # apply weight demodulation?
    flip_weight=True,  # False = convolution, True = correlation
    fused_modconv=True,  # modulate, convolve and demodulate as one grouped conv?
):
    """Apply a per-sample style-modulated convolution to ``x``.

    Each sample's weights are ``weight`` scaled multiplicatively along the
    input-channel axis by that sample's ``styles``. With ``demodulate``, each
    output channel is then normalized by the rsqrt of its summed squared
    weights.

    With ``fused_modconv`` the whole operation runs as one grouped
    convolution (one group per sample). Otherwise the activations are scaled
    before and after a single shared convolution. Both paths compute the same
    result, up to floating-point rounding.

    Returns:
        The output activations. The fused path reshapes them to have
        ``batch_size`` as the leading dimension. Presumably the shape is
        [batch_size, out_channels, out_height, out_width]; confirm against
        conv2d_resample.

    NOTE(review): this file contains another ``modulated_conv2d`` definition.
    The last one defined wins at import time; confirm only one is intended.
    """
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape
    misc.assert_shape(weight, [out_channels, in_channels, kh, kw])
    misc.assert_shape(x, [batch_size, in_channels, None, None])
    misc.assert_shape(styles, [batch_size, in_channels])

    # Pre-normalize inputs to avoid FP16 overflow. This is only safe when
    # demodulating, because demodulation cancels the per-row scale factors.
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(
            float('inf'), dim=[1, 2, 3], keepdim=True))
        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True)

    # Calculate per-sample weights [N, O, I, kh, kw] and demodulation
    # coefficients [N, O].
    w = None
    demod_coeff = None
    if demodulate or fused_modconv:
        w = weight.unsqueeze(0)
        # Modulation is multiplicative. This must match the non-fused path
        # below, which multiplies x by styles.
        w = w * styles.reshape(batch_size, 1, -1, 1, 1)
    if demodulate:
        demod_coeff = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()
    if demodulate and fused_modconv:
        w = w * demod_coeff.reshape(batch_size, -1, 1, 1, 1)

    # Non-fused path: scale activations before and after the convolution.
    if not fused_modconv:
        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
        x = conv2d_resample.conv2d_resample(
            x=x,
            w=weight.to(x.dtype),
            f=resample_filter,
            up=up,
            down=down,
            padding=padding,
            flip_weight=flip_weight,
        )
        if demodulate and noise is not None:
            x = fma.fma(x, demod_coeff.to(x.dtype).reshape(batch_size, -1, 1, 1),
                        noise.to(x.dtype))
        elif demodulate:
            x = x * demod_coeff.to(x.dtype).reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x

    # Fused path: fold the batch into the channel axis and run one grouped
    # convolution with one group per sample.
    with misc.suppress_tracer_warnings():  # this value will be treated as a constant
        batch_size = int(batch_size)
    misc.assert_shape(x, [batch_size, in_channels, None, None])
    # [N, I, H, W] -> [1, N*I, H, W]. The -1 keeps every sample's channels,
    # as required by groups=batch_size.
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_resample.conv2d_resample(
        x=x,
        w=w.to(x.dtype),
        f=resample_filter,
        up=up,
        down=down,
        padding=padding,
        groups=batch_size,
        flip_weight=flip_weight,
    )
    x = x.reshape(batch_size, -1, *x.shape[2:])
    if noise is not None:
        x = x.add_(noise)
    return x