def forward(self, input, style):
    """
    Return the style-modulated convolution of ``input``.

    Parameters
    ----------
    input: pytorch tensor, unpacked as (batch, in_channel, height, width);
        the appearance feature map to convolve.
    style: pytorch tensor for the attribute-editing latent; it is passed
        through ``self.modulation`` and viewed as one ``in_channel``-sized
        scaling vector per sample.

    Returns
    -------
    pytorch tensor of shape (batch, self.out_channel, H, W). H and W depend
    on which branch runs: stride-2 transposed conv followed by blur when
    ``self.upsample``, blur followed by stride-2 conv when
    ``self.downsample``, otherwise a plain conv with ``self.padding``.
    """
    batch, in_channel, height, width = input.shape
    ksize = self.kernel_size
    n_out = self.out_channel

    # Scale the shared kernel per sample and per input channel.
    per_sample = self.modulation(style).view(batch, 1, in_channel, 1, 1)
    kernels = self.scale * self.weight * per_sample

    if self.demodulate:
        # Normalise each (sample, output-channel) kernel by the inverse root of
        # its sum of squares; the 1e-8 keeps rsqrt finite for all-zero kernels.
        inv_norm = torch.rsqrt(kernels.pow(2).sum([2, 3, 4]) + 1e-8)
        kernels = kernels * inv_norm.view(batch, n_out, 1, 1, 1)

    # Stack every sample's kernels so a single grouped convolution
    # (groups=batch) applies each sample's own weights to its own channels.
    kernels = kernels.view(batch * n_out, in_channel, ksize, ksize)

    def ungroup(folded_out):
        # Undo the (1, batch * C, H, W) folding on the convolution output.
        _, _, out_h, out_w = folded_out.shape
        return folded_out.view(batch, n_out, out_h, out_w)

    if self.upsample:
        folded = input.view(1, batch * in_channel, height, width)
        # Transposed conv takes (in, out // groups, k, k) weights (as in
        # torch.nn.functional.conv_transpose2d), so swap the out/in axes
        # inside each sample's group.
        kernels = kernels.view(batch, n_out, in_channel, ksize, ksize)
        kernels = kernels.transpose(1, 2).reshape(
            batch * in_channel, n_out, ksize, ksize
        )
        result = conv2d_gradfix.conv_transpose2d(
            folded, kernels, padding=0, stride=2, groups=batch
        )
        return self.blur(ungroup(result))

    if self.downsample:
        blurred = self.blur(input)
        _, _, blur_h, blur_w = blurred.shape
        folded = blurred.view(1, batch * in_channel, blur_h, blur_w)
        result = conv2d_gradfix.conv2d(
            folded, kernels, padding=0, stride=2, groups=batch
        )
        return ungroup(result)

    folded = input.view(1, batch * in_channel, height, width)
    result = conv2d_gradfix.conv2d(
        folded, kernels, padding=self.padding, groups=batch
    )
    return ungroup(result)
def forward(self, input):
    """
    Return the convolution of ``input`` with the scaled weight.

    Parameters
    ----------
    input: pytorch tensor, the feature map to convolve.

    Returns
    -------
    The result of ``conv2d_gradfix.conv2d`` using ``self.weight * self.scale``
    as the kernel, plus ``self.bias``, ``self.stride`` and ``self.padding``.
    """
    # The weight is multiplied by self.scale at call time instead of being
    # stored pre-scaled.
    scaled_weight = self.weight * self.scale
    return conv2d_gradfix.conv2d(
        input,
        scaled_weight,
        bias=self.bias,
        stride=self.stride,
        padding=self.padding,
    )
def forward(self, input):
    """
    Return the convolution of ``input``.

    Parameters
    ----------
    input: pytorch tensor, used as the input of the convolution.

    Returns
    -------
    The convolved tensor, computed by ``conv2d_gradfix.conv2d`` with kernel
    ``self.weight * self.scale``, bias ``self.bias``, stride ``self.stride``
    and padding ``self.padding``.
    """
    conv_kwargs = {
        "bias": self.bias,
        "stride": self.stride,
        "padding": self.padding,
    }
    return conv2d_gradfix.conv2d(input, self.weight * self.scale, **conv_kwargs)
def forward(self, input, style):
    """
    Return the style-modulated convolution of ``input``.

    Two implementations are selected by ``self.fused``:

    * unfused: the style vector scales the input activations, the shared
      (scaled) kernel is applied with an ordinary convolution, and, when
      ``self.demodulate`` is set, the outputs are rescaled per sample and
      output channel afterwards;
    * fused: per-sample kernels are built by modulating (and optionally
      demodulating) the shared kernel, then applied in one grouped
      convolution with ``groups=batch``.

    Parameters
    ----------
    input: pytorch tensor, unpacked as (batch, in_channel, height, width).
    style: pytorch tensor, passed through ``self.modulation`` to obtain one
        ``in_channel``-sized scaling vector per sample.

    Returns
    -------
    pytorch tensor with ``batch`` samples; its spatial size depends on the
    upsample / downsample / plain branch taken.
    """
    batch, in_channel, height, width = input.shape

    if not self.fused:
        base = self.scale * self.weight.squeeze(0)
        coeffs = self.modulation(style)

        if self.demodulate:
            # Norm of the would-be modulated kernel, computed only to derive
            # the per-(sample, out-channel) output rescaling factors.
            modulated = base.unsqueeze(0) * coeffs.view(batch, 1, in_channel, 1, 1)
            dcoefs = (modulated.square().sum((2, 3, 4)) + 1e-8).rsqrt()

        # Modulate the activations instead of the kernel.
        scaled_input = input * coeffs.reshape(batch, in_channel, 1, 1)

        if self.upsample:
            out = self.blur(
                conv2d_gradfix.conv_transpose2d(
                    scaled_input, base.transpose(0, 1), padding=0, stride=2
                )
            )
        elif self.downsample:
            out = conv2d_gradfix.conv2d(
                self.blur(scaled_input), base, padding=0, stride=2
            )
        else:
            out = conv2d_gradfix.conv2d(scaled_input, base, padding=self.padding)

        if self.demodulate:
            out = out * dcoefs.view(batch, -1, 1, 1)
        return out

    ksize = self.kernel_size
    n_out = self.out_channel

    mod = self.modulation(style).view(batch, 1, in_channel, 1, 1)
    kernels = self.scale * self.weight * mod
    if self.demodulate:
        inv_norm = torch.rsqrt(kernels.pow(2).sum([2, 3, 4]) + 1e-8)
        kernels = kernels * inv_norm.view(batch, n_out, 1, 1, 1)
    # Fold the batch into the output-channel axis for the grouped convolution.
    kernels = kernels.view(batch * n_out, in_channel, ksize, ksize)

    upsample = self.upsample
    if upsample:
        # Transposed conv wants (in, out // groups, k, k) weights per group.
        kernels = kernels.view(batch, n_out, in_channel, ksize, ksize).transpose(1, 2)
        kernels = kernels.reshape(batch * in_channel, n_out, ksize, ksize)
        out = conv2d_gradfix.conv_transpose2d(
            input.view(1, batch * in_channel, height, width),
            kernels,
            padding=0,
            stride=2,
            groups=batch,
        )
    elif self.downsample:
        blurred = self.blur(input)
        _, _, blur_h, blur_w = blurred.shape
        out = conv2d_gradfix.conv2d(
            blurred.view(1, batch * in_channel, blur_h, blur_w),
            kernels,
            padding=0,
            stride=2,
            groups=batch,
        )
    else:
        out = conv2d_gradfix.conv2d(
            input.view(1, batch * in_channel, height, width),
            kernels,
            padding=self.padding,
            groups=batch,
        )

    # Un-fold (1, batch * n_out, H, W) back to (batch, n_out, H, W).
    _, _, out_h, out_w = out.shape
    out = out.view(batch, n_out, out_h, out_w)
    if upsample:
        out = self.blur(out)
    return out