import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair


def forward(self, input):
    # Optionally apply dropout to the convolution weights before the conv.
    if self.dropout_fn is not None:
        dropped_w = self.dropout_fn.forward(self.weight, self.training)
    else:
        dropped_w = self.weight
    if self.padding_mode == 'circular':
        # Circular padding is applied manually via F.pad, so the conv
        # itself runs with zero padding.
        expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                            (self.padding[0] + 1) // 2, self.padding[0] // 2)
        return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
                        dropped_w, self.bias, self.stride,
                        _pair(0), self.dilation, self.groups)
    return F.conv2d(input, dropped_w, self.bias, self.stride,
                    self.padding, self.dilation, self.groups)
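
# A minimal standalone sketch (not from the original module) of what the
# circular branch above computes: F.pad with mode='circular' wraps the
# input around before a zero-padded convolution. Shapes are illustrative.
x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
padded = F.pad(x, (1, 1, 1, 1), mode='circular')  # wrap-around padding
y = F.conv2d(padded, w, padding=0)                # spatial size preserved
assert y.shape == (1, 4, 8, 8)
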
def conv_ws_2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
               groups=1, eps=1e-5):
    """Conv2d with weight standardization.

    :param input: input feature map
    :type input: torch.Tensor
    :param weight: weight of conv layer
    :type weight: torch.Tensor
    :param bias: bias
    :type bias: torch.Tensor
    :param stride: conv stride
    :type stride: int
    :param padding: number of padding pixels
    :type padding: int
    :param dilation: dilation of the convolution
    :type dilation: int
    :param groups: number of groups
    :type groups: int
    :param eps: small constant added to the std for numerical stability
    :type eps: float
    :return: feature map after weight standardization
    :rtype: torch.Tensor
    """
    c_out = weight.size(0)
    weight_flat = weight.view(c_out, -1)
    # Standardize each output filter to zero mean and unit std.
    mean = weight_flat.mean(dim=1, keepdim=True).view(c_out, 1, 1, 1)
    std = weight_flat.std(dim=1, keepdim=True).view(c_out, 1, 1, 1)
    weight = (weight - mean) / (std + eps)
    return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
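
# Hedged usage sketch for conv_ws_2d: standardize an ordinary nn.Conv2d's
# weights on the fly. Layer sizes and shapes are assumptions for illustration.
conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
x = torch.randn(2, 3, 32, 32)
y = conv_ws_2d(x, conv.weight, conv.bias, stride=1, padding=1)
assert y.shape == (2, 16, 32, 32)
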
def apply_cdna_kernel(image, kernel):
    """Apply per-sample CDNA kernels to a batch of images.

    Inputs:
        image  -> tensor[B, C, H, W]
        kernel -> tensor[B, N, K, K]
    Outputs:
        new_image -> tensor[B, N, C, H, W]
    """
    batch_size = image.shape[0]
    image_channel = image.shape[1]
    image_height = image.shape[2]
    image_width = image.shape[3]
    num_kernel = kernel.shape[1]
    kernel_size = kernel.shape[2]
    padding = kernel_size // 2

    # Fold the batch into the channel dimension so a single grouped conv
    # applies each sample's kernels to that sample only.
    _image = image.transpose(0, 1)  # [C, B, H, W]
    _kernel = kernel.view(batch_size * num_kernel, 1, kernel_size, kernel_size)  # [B * N, 1, K, K]
    output = F.conv2d(_image, _kernel, stride=1, padding=padding, groups=batch_size)  # [C, B * N, H, W]
    output = output.view(image_channel, batch_size, num_kernel, image_height, image_width)  # [C, B, N, H, W]
    return output.permute(1, 2, 0, 3, 4)  # [B, N, C, H, W]
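
# Illustrative check of apply_cdna_kernel (shapes assumed): B=2 images,
# N=4 kernels per sample, K=5. The softmax makes each kernel a valid
# spatial distribution, as in CDNA, though any values would work here.
image = torch.randn(2, 3, 16, 16)
kernel = torch.randn(2, 4, 5 * 5).softmax(dim=-1).view(2, 4, 5, 5)
out = apply_cdna_kernel(image, kernel)
assert out.shape == (2, 4, 3, 16, 16)  # [B, N, C, H, W]
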
def forward(self, input):
    out = F.conv2d(input, self.weight * self.scale, bias=self.bias,
                   stride=self.stride, padding=self.padding)
    return out
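
# In equalized-learning-rate convolutions (as in StyleGAN), self.scale is
# typically 1 / sqrt(fan_in). A sketch of how such a layer might define it;
# the class name and initialization are assumptions, not the original code.
import math

class EqualConv2dSketch(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size))
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)  # fan-in scale
        self.bias = nn.Parameter(torch.zeros(out_channel))
        self.stride = stride
        self.padding = padding

    def forward(self, input):
        return F.conv2d(input, self.weight * self.scale, bias=self.bias,
                        stride=self.stride, padding=self.padding)
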
def _ssim(img1, img2, window, window_size, channel, size_average=True):
    # Local means via a Gaussian window applied per channel (groups=channel).
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local (co)variances from the windowed second moments.
    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    # Stabilizing constants from the SSIM paper (pixel values assumed in [0, 1]).
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)
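
# _ssim expects a normalized Gaussian window shaped
# [channel, 1, window_size, window_size] for the grouped convolutions.
# A common construction (a sketch; sigma=1.5 follows the SSIM paper):
def create_window(window_size, channel, sigma=1.5):
    coords = torch.arange(window_size, dtype=torch.float32) - window_size // 2
    gauss = torch.exp(-coords.pow(2) / (2 * sigma ** 2))
    gauss = gauss / gauss.sum()
    window_2d = gauss.unsqueeze(1) @ gauss.unsqueeze(0)  # outer product
    return window_2d.expand(channel, 1, window_size, window_size).contiguous()

img1 = torch.rand(1, 3, 64, 64)
img2 = torch.rand(1, 3, 64, 64)
window = create_window(11, channel=3)
score = _ssim(img1, img2, window, window_size=11, channel=3, size_average=True)
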
def blur(self, tensor_image):
    channels = tensor_image.size(1)
    out_channel = channels
    # Fixed (non-learnable) 3x3 kernel that samples the four corner pixels.
    kernel = torch.tensor([[1., 0., 1.],
                           [0., 0., 0.],
                           [1., 0., 1.]]).expand(out_channel, channels, 3, 3)
    weight = nn.Parameter(data=kernel, requires_grad=False)
    return F.conv2d(tensor_image, weight, padding=1)
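
# blur above never reads self, so it can be exercised standalone
# (a sketch; shapes are illustrative):
x = torch.rand(2, 3, 16, 16)
out = blur(None, x)  # self is unused in the method body
assert out.shape == (2, 3, 16, 16)
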
def forward(self, input, eps=1e-4):
    weight = self.standardize_weight(eps)
    return F.conv2d(input, weight, self.bias, self.stride,
                    self.padding, self.dilation, self.groups)
def forward(self, x):
    return F.conv2d(x, self.get_weight(), self.bias, self.stride,
                    self.padding, self.dilation, self.groups)
def forward(self, img):
    img = F.conv2d(img, self.weight, None, stride=self.stride, padding=self.padding)
    # Note: the bias is added only to the first sample in the batch.
    img[0] += self.bias.view(-1, 1, 1)
    return img
def forward(self, input, style):
    batch, in_channel, height, width = input.shape

    # Modulate the weights with the style vector.
    style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
    weight = self.scale * self.weight * style

    if self.demodulate:
        # Demodulation: rescale each output filter to unit norm.
        # A good explanation of modulation/demodulation:
        # https://youtu.be/MYCTn80qSk0?t=142
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
        weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

    weight = weight.view(batch * self.out_channel, in_channel,
                         self.kernel_size, self.kernel_size)

    if self.up_sample:
        # Reshape input for the transposed convolution
        input = input.view(1, batch * in_channel, height, width)
        # Reshape weights for the transposed convolution
        weight = weight.view(batch, self.out_channel, in_channel,
                             self.kernel_size, self.kernel_size)
        weight = weight.transpose(1, 2).reshape(batch * in_channel, self.out_channel,
                                                self.kernel_size, self.kernel_size)
        # Upsampling convolution
        out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
        _, _, height, width = out.shape
        # Reshape output to the batched shape
        out = out.view(batch, self.out_channel, height, width)
        # Blur the output
        out = self.blur(out)
    elif self.down_sample:
        # Blur the image before convolution
        input = self.blur(input)
        # Reshape input for convolution
        _, _, height, width = input.shape
        input = input.view(1, batch * in_channel, height, width)
        # Downsampling convolution
        out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
        _, _, height, width = out.shape
        # Reshape output to the batched shape
        out = out.view(batch, self.out_channel, height, width)
    else:
        # Reshape input for convolution
        input = input.view(1, batch * in_channel, height, width)
        # Plain stride-1 convolution
        out = F.conv2d(input, weight, padding=self.padding, groups=batch)
        _, _, height, width = out.shape
        # Reshape output to the batched shape
        out = out.view(batch, self.out_channel, height, width)

    return out
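
# The reshapes above implement per-sample weights via a grouped convolution:
# folding the batch into the channel dimension lets one F.conv2d apply a
# different filter bank to each sample. A minimal standalone illustration
# of that trick (shapes assumed, independent of the class above):
batch, in_channel, out_channel, k = 2, 8, 16, 3
x = torch.randn(batch, in_channel, 32, 32)
w = torch.randn(batch, out_channel, in_channel, k, k)  # per-sample weights
x_ = x.view(1, batch * in_channel, 32, 32)
w_ = w.view(batch * out_channel, in_channel, k, k)
y = F.conv2d(x_, w_, padding=k // 2, groups=batch)
y = y.view(batch, out_channel, 32, 32)
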