Example #1
    def forward(self, x, style, input_gain=None):
        n, c, h, w = x.shape

        weight = self.weight
        # Pre-normalize inputs to avoid FP16 overflow.
        if x.dtype == torch.float16 and self.demodulate:
            weight = weight * (
                1 / np.sqrt(
                    self.in_channels * self.kernel_size * self.kernel_size) /
                weight.norm(float('inf'), dim=[1, 2, 3], keepdim=True)
            )  # max_Ikk
            style = style / style.norm(
                float('inf'), dim=1, keepdim=True)  # max_I

        # process style code
        style = self.style_modulation(style).view(n, 1, c, 1,
                                                  1) + self.style_bias
        # combine weight and style
        weight = weight * style
        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(n, self.out_channels, 1, 1, 1)

        if input_gain is not None:
            # input_gain shape [batch, in_ch]
            input_gain = input_gain.expand(n, self.in_channels)
            # weight shape [batch, out_ch, in_ch, kernel_size, kernel_size]
            weight = weight * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4)

        weight = weight.view(n * self.out_channels, c, self.kernel_size,
                             self.kernel_size)

        weight = weight.to(x.dtype)
        if self.upsample:
            x = x.reshape(1, n * c, h, w)
            weight = weight.view(n, self.out_channels, c, self.kernel_size,
                                 self.kernel_size)
            weight = weight.transpose(1, 2).reshape(n * c, self.out_channels,
                                                    self.kernel_size,
                                                    self.kernel_size)
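            # conv_transpose2d expects weights of shape
            # (in_channels, out_channels // groups, k, k), hence the
            # out/in transpose above before folding the batch into groups.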
            x = conv_transpose2d(x, weight, padding=0, stride=2, groups=n)
            x = x.reshape(n, self.out_channels, *x.shape[-2:])
            x = self.blur(x)

        elif self.downsample:
            x = self.blur(x)
            x = x.view(1, n * self.in_channels, *x.shape[-2:])
            x = conv2d(x, weight, stride=2, padding=0, groups=n)
            x = x.view(n, self.out_channels, *x.shape[-2:])
        else:
            x = x.reshape(1, n * c, h, w)
            x = conv2d(x, weight, stride=1, padding=self.padding, groups=n)
            x = x.view(n, self.out_channels, *x.shape[-2:])
        return x
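For reference, here is a minimal, self-contained sketch of the modulate/demodulate/grouped-convolution pattern that both examples on this page build on. The function name modulated_conv2d, the shapes, and the use of torch.nn.functional.conv2d in place of the repository's custom conv2d are illustrative assumptions, not part of the module shown above.

import torch
import torch.nn.functional as F

def modulated_conv2d(x, weight, style, eps=1e-8, padding=1):
    # x:      (n, in_ch, h, w) input features
    # weight: (out_ch, in_ch, k, k) convolution weight shared across the batch
    # style:  (n, in_ch) per-sample channel scales
    n, c, h, w = x.shape
    out_ch, _, k, _ = weight.shape

    # Modulate: scale each input channel of the weight per sample.
    w_mod = weight.unsqueeze(0) * style.view(n, 1, c, 1, 1)  # (n, out_ch, in_ch, k, k)

    # Demodulate: rescale each output filter to unit L2 norm per sample.
    demod = torch.rsqrt(w_mod.pow(2).sum(dim=[2, 3, 4]) + eps)  # (n, out_ch)
    w_mod = w_mod * demod.view(n, out_ch, 1, 1, 1)

    # Fold the batch into the group dimension so every sample is convolved
    # with its own modulated kernel in a single grouped conv call.
    x = x.reshape(1, n * c, h, w)
    w_mod = w_mod.reshape(n * out_ch, c, k, k)
    out = F.conv2d(x, w_mod, padding=padding, groups=n)
    return out.view(n, out_ch, *out.shape[-2:])

x = torch.randn(2, 8, 16, 16)
weight = torch.randn(4, 8, 3, 3)
style = torch.randn(2, 8)
print(modulated_conv2d(x, weight, style).shape)  # torch.Size([2, 4, 16, 16])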
Example #2
    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Input features with shape of (N, C, H, W).
            style (Tensor): Style latent with shape of (N, C).

        Returns:
            Tensor: Output features with shape of (N, C_out, H_out, W_out);
                the spatial size depends on the up/downsampling mode.
        """
        n, c, h, w = x.shape
        # process style code
        style = self.style_modulation(style).view(n, 1, c, 1,
                                                  1) + self.style_bias

        # combine weight and style
        weight = self.weight * style
        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(n, self.out_channels, 1, 1, 1)

        weight = weight.view(n * self.out_channels, c, self.kernel_size,
                             self.kernel_size)

        if self.upsample and not self.deconv2conv:
            x = x.reshape(1, n * c, h, w)
            weight = weight.view(n, self.out_channels, c, self.kernel_size,
                                 self.kernel_size)
            weight = weight.transpose(1, 2).reshape(n * c, self.out_channels,
                                                    self.kernel_size,
                                                    self.kernel_size)
            x = conv_transpose2d(x, weight, padding=0, stride=2, groups=n)
            x = x.reshape(n, self.out_channels, *x.shape[-2:])
            x = self.blur(x)
        elif self.upsample and self.deconv2conv:
            if self.up_after_conv:
                x = x.reshape(1, n * c, h, w)
                x = conv2d(x, weight, padding=self.padding, groups=n)
                x = x.view(n, self.out_channels, *x.shape[2:4])

            if self.with_interp_pad:
                h_, w_ = x.shape[-2:]
                up_cfg_ = deepcopy(self.up_config)
                up_scale = up_cfg_.pop('scale_factor')
                size_ = (h_ * up_scale + self.interp_pad,
                         w_ * up_scale + self.interp_pad)
                x = F.interpolate(x, size=size_, **up_cfg_)
            else:
                x = F.interpolate(x, **self.up_config)

            if not self.up_after_conv:
                h_, w_ = x.shape[-2:]
                x = x.view(1, n * c, h_, w_)
                x = conv2d(x, weight, padding=self.padding, groups=n)
                x = x.view(n, self.out_channels, *x.shape[2:4])

        elif self.downsample:
            x = self.blur(x)
            x = x.view(1, n * self.in_channels, *x.shape[-2:])
            x = conv2d(x, weight, stride=2, padding=0, groups=n)
            x = x.view(n, self.out_channels, *x.shape[-2:])
        else:
            x = x.view(1, n * c, h, w)
            x = conv2d(x, weight, stride=1, padding=self.padding, groups=n)
            x = x.view(n, self.out_channels, *x.shape[-2:])

        return x
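The with_interp_pad branch in Example #2 simply enlarges the F.interpolate target size by interp_pad pixels beyond the configured scale factor. Below is a standalone sketch of that sizing logic, assuming an up_config such as dict(scale_factor=2, mode='bilinear', align_corners=False); the concrete config values are illustrative.

import torch
import torch.nn.functional as F
from copy import deepcopy

up_config = dict(scale_factor=2, mode='bilinear', align_corners=False)
interp_pad = 1  # extra pixels added on top of the scaled size

x = torch.randn(1, 4, 8, 8)
h_, w_ = x.shape[-2:]
up_cfg_ = deepcopy(up_config)
up_scale = up_cfg_.pop('scale_factor')
size_ = (h_ * up_scale + interp_pad, w_ * up_scale + interp_pad)
out = F.interpolate(x, size=size_, **up_cfg_)
print(out.shape)  # torch.Size([1, 4, 17, 17])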
Example #3
    def test_conv2d_cuda(self):
        x = self.input.cuda()
        weight = self.weight.cuda()
        # conv2d(input, weight, bias=None, stride=1, padding=1)
        res = conv2d(x, weight, None, 1, 1)
        assert res.shape == (1, 1, 32, 32)
        # Check second-order gradients of the custom conv2d op.
        gradgradcheck(partial(conv2d, weight=weight, padding=1, stride=1), x)
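gradgradcheck compares analytic second-order gradients against finite differences and is only reliable in double precision; self.input and self.weight in the test above are presumably fixtures built elsewhere in the test class. A self-contained variant, using torch.nn.functional.conv2d in place of the repository's custom conv2d:

import torch
from functools import partial
from torch.autograd import gradcheck, gradgradcheck

x = torch.randn(1, 1, 8, 8, dtype=torch.double, requires_grad=True)
weight = torch.randn(1, 1, 3, 3, dtype=torch.double, requires_grad=True)
conv = partial(torch.nn.functional.conv2d, weight=weight, padding=1, stride=1)

assert gradcheck(conv, (x,))      # first-order gradients
assert gradgradcheck(conv, (x,))  # second-order gradients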