def forward(self, x, style, input_gain=None):
    """Forward function.

    Args:
        x (Tensor): Input features with shape of (N, C, H, W).
        style (Tensor): Style latent with shape of (N, C).
        input_gain (Tensor, optional): Per-input-channel gain applied to
            the modulated weight, with shape of (N, C). Defaults to None.

    Returns:
        Tensor: Output feature with shape of (N, C', H', W').
    """
    n, c, h, w = x.shape
    weight = self.weight

    # Pre-normalize inputs to avoid FP16 overflow.
    if x.dtype == torch.float16 and self.demodulate:
        weight = weight * (
            1 / np.sqrt(
                self.in_channels * self.kernel_size * self.kernel_size) /
            weight.norm(float('inf'), dim=[1, 2, 3], keepdim=True))  # max_Ikk
        style = style / style.norm(
            float('inf'), dim=1, keepdim=True)  # max_I

    # process style code
    style = self.style_modulation(style).view(n, 1, c, 1, 1) + self.style_bias

    # combine weight and style
    weight = weight * style
    if self.demodulate:
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
        weight = weight * demod.view(n, self.out_channels, 1, 1, 1)

    if input_gain is not None:
        # input_gain shape [batch, in_ch]
        input_gain = input_gain.expand(n, self.in_channels)
        # weight shape [batch, out_ch, in_ch, kernel_size, kernel_size]
        weight = weight * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4)

    weight = weight.view(n * self.out_channels, c, self.kernel_size,
                         self.kernel_size)
    weight = weight.to(x.dtype)

    if self.upsample:
        x = x.reshape(1, n * c, h, w)
        weight = weight.view(n, self.out_channels, c, self.kernel_size,
                             self.kernel_size)
        weight = weight.transpose(1, 2).reshape(n * c, self.out_channels,
                                                self.kernel_size,
                                                self.kernel_size)
        x = conv_transpose2d(x, weight, padding=0, stride=2, groups=n)
        x = x.reshape(n, self.out_channels, *x.shape[-2:])
        x = self.blur(x)
    elif self.downsample:
        x = self.blur(x)
        x = x.view(1, n * self.in_channels, *x.shape[-2:])
        x = conv2d(x, weight, stride=2, padding=0, groups=n)
        x = x.view(n, self.out_channels, *x.shape[-2:])
    else:
        x = x.reshape(1, n * c, h, w)
        x = conv2d(x, weight, stride=1, padding=self.padding, groups=n)
        x = x.view(n, self.out_channels, *x.shape[-2:])
    return x
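
# A minimal, self-contained sketch (plain PyTorch; `n`, `out_ch`, `in_ch` and
# `k` are illustrative names, not taken from the module above) of the
# modulate / demodulate step used in the forward pass: per-sample styles scale
# the input-channel axis of a shared weight, then demodulation rescales every
# output filter back to unit L2 norm, which is what keeps activation
# magnitudes stable without an explicit normalization layer.
import torch

n, out_ch, in_ch, k = 2, 8, 4, 3
weight = torch.randn(1, out_ch, in_ch, k, k)  # shared weight, broadcast over batch
style = torch.randn(n, 1, in_ch, 1, 1)        # per-sample, per-input-channel scale

w = weight * style                                    # modulate
demod = torch.rsqrt(w.pow(2).sum([2, 3, 4]) + 1e-8)  # per-(sample, filter) factor
w = w * demod.view(n, out_ch, 1, 1, 1)                # demodulate

# Every (sample, out_ch) filter now has approximately unit L2 norm.
assert torch.allclose(
    w.pow(2).sum([2, 3, 4]).sqrt(), torch.ones(n, out_ch), atol=1e-4)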
def forward(self, x, style):
    """Forward function.

    Args:
        x (Tensor): Input features with shape of (N, C, H, W).
        style (Tensor): Style latent with shape of (N, C).

    Returns:
        Tensor: Output feature with shape of (N, C, H, W).
    """
    n, c, h, w = x.shape
    # process style code
    style = self.style_modulation(style).view(n, 1, c, 1, 1) + self.style_bias

    # combine weight and style
    weight = self.weight * style
    if self.demodulate:
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
        weight = weight * demod.view(n, self.out_channels, 1, 1, 1)

    weight = weight.view(n * self.out_channels, c, self.kernel_size,
                         self.kernel_size)

    if self.upsample and not self.deconv2conv:
        x = x.reshape(1, n * c, h, w)
        weight = weight.view(n, self.out_channels, c, self.kernel_size,
                             self.kernel_size)
        weight = weight.transpose(1, 2).reshape(n * c, self.out_channels,
                                                self.kernel_size,
                                                self.kernel_size)
        x = conv_transpose2d(x, weight, padding=0, stride=2, groups=n)
        x = x.reshape(n, self.out_channels, *x.shape[-2:])
        x = self.blur(x)
    elif self.upsample and self.deconv2conv:
        if self.up_after_conv:
            x = x.reshape(1, n * c, h, w)
            x = conv2d(x, weight, padding=self.padding, groups=n)
            x = x.view(n, self.out_channels, *x.shape[2:4])

        if self.with_interp_pad:
            h_, w_ = x.shape[-2:]
            up_cfg_ = deepcopy(self.up_config)
            up_scale = up_cfg_.pop('scale_factor')
            size_ = (h_ * up_scale + self.interp_pad,
                     w_ * up_scale + self.interp_pad)
            x = F.interpolate(x, size=size_, **up_cfg_)
        else:
            x = F.interpolate(x, **self.up_config)

        if not self.up_after_conv:
            h_, w_ = x.shape[-2:]
            x = x.view(1, n * c, h_, w_)
            x = conv2d(x, weight, padding=self.padding, groups=n)
            x = x.view(n, self.out_channels, *x.shape[2:4])
    elif self.downsample:
        x = self.blur(x)
        x = x.view(1, n * self.in_channels, *x.shape[-2:])
        x = conv2d(x, weight, stride=2, padding=0, groups=n)
        x = x.view(n, self.out_channels, *x.shape[-2:])
    else:
        x = x.view(1, n * c, h, w)
        x = conv2d(x, weight, stride=1, padding=self.padding, groups=n)
        x = x.view(n, self.out_channels, *x.shape[-2:])
    return x
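
# A minimal sketch (plain PyTorch; shapes and names below are illustrative)
# of the grouped-convolution trick used in every branch above: folding the
# batch into the group dimension lets a single F.conv2d call apply a
# different filter bank to each sample, instead of looping over the batch.
import torch
import torch.nn.functional as F

n, in_ch, out_ch, k, h, w = 2, 4, 8, 3, 16, 16
x = torch.randn(n, in_ch, h, w)
weight = torch.randn(n, out_ch, in_ch, k, k)  # one filter bank per sample

# batched path: (1, n*in_ch, H, W) input, (n*out_ch, in_ch, k, k) weight,
# groups=n so sample i only sees filter bank i
out = F.conv2d(
    x.reshape(1, n * in_ch, h, w),
    weight.reshape(n * out_ch, in_ch, k, k),
    padding=k // 2,
    groups=n).view(n, out_ch, h, w)

# reference path: per-sample convolutions in a Python loop
ref = torch.cat(
    [F.conv2d(x[i:i + 1], weight[i], padding=k // 2) for i in range(n)])
assert torch.allclose(out, ref, atol=1e-5)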
def test_conv2d_cuda(self):
    x = self.input.cuda()
    weight = self.weight.cuda()
    res = conv2d(x, weight, None, 1, 1)
    assert res.shape == (1, 1, 32, 32)
    gradgradcheck(partial(conv2d, weight=weight, padding=1, stride=1), x)
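
# A self-contained sketch of the same second-order gradient check, using
# torch.nn.functional.conv2d in place of the repo's conv2d wrapper (an
# assumption; the two are only interchangeable for this check if the wrapper
# matches F.conv2d semantics). Note gradgradcheck expects double-precision
# tensors with requires_grad=True; the shapes below are illustrative.
from functools import partial

import torch
import torch.nn.functional as F
from torch.autograd import gradgradcheck

x = torch.randn(1, 3, 8, 8, dtype=torch.float64, requires_grad=True)
weight = torch.randn(1, 3, 3, 3, dtype=torch.float64, requires_grad=True)

# Checks gradients of gradients via finite differences; raises on mismatch.
gradgradcheck(partial(F.conv2d, weight=weight, padding=1, stride=1), (x, ))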