def forward(self, x):
    """Denoise ``x`` with the configured filter and add it back as a residual.

    The filter is selected by ``self.filter_type``; the filtered tensor is
    passed through ``self.conv`` and summed with the input.

    Args:
        x: Input tensor. The Gaussian branch reads ``x.shape[2]`` and
            ``x.shape[3]``, so at least four dimensions are required
            (presumably ``(B, C, H, W)`` -- confirm against the caller).

    Returns:
        ``x + self.conv(x_denoised)``.

    Raises:
        ValueError: If ``self.filter_type`` is not ``'Median_Filter'``,
            ``'Mean_Filter'`` or ``'Gaussian_Filter'``.
    """
    if self.filter_type == 'Median_Filter':
        x_denoised = kornia.median_blur(x, (self.ksize, self.ksize))
    elif self.filter_type == 'Mean_Filter':
        x_denoised = kornia.box_blur(x, (self.ksize, self.ksize))
    elif self.filter_type == 'Gaussian_Filter':
        # NOTE(review): sigma is derived from the spatial size of ``x``, not
        # from ``self.ksize``. The expression has the same form as OpenCV's
        # default-sigma rule, which is defined on the kernel size -- confirm
        # that using the feature-map size here is intended.
        sigma = (
            0.3 * ((x.shape[3] - 1) * 0.5 - 1) + 0.8,
            0.3 * ((x.shape[2] - 1) * 0.5 - 1) + 0.8,
        )
        x_denoised = kornia.gaussian_blur2d(
            x, (self.ksize, self.ksize), sigma)
    else:
        # Without this branch an unknown name left ``x_denoised`` unbound and
        # surfaced as an UnboundLocalError on the line below.
        raise ValueError(
            f"Unsupported filter_type: {self.filter_type!r}")

    new_x = x + self.conv(x_denoised)
    return new_x
def forward(self, x):
    """Denoise ``x`` with the configured filter and add it back as a residual.

    The filter is selected by ``self.filter_type``; the filtered tensor is
    passed through ``self.conv_1x1`` and summed with the input.

    Args:
        x: Input tensor. The Gaussian branch reads ``x.shape[2]`` and
            ``x.shape[3]``, so at least four dimensions are required
            (presumably ``(B, C, H, W)`` -- confirm against the caller).

    Returns:
        ``x + self.conv_1x1(x_denoised)``.

    Raises:
        ValueError: If ``self.filter_type`` is not ``'Median_Filter'``,
            ``'Mean_Filter'``, ``'Gaussian_Filter'`` or ``'NonLocal_Filter'``.
    """
    if self.filter_type == 'Median_Filter':
        x_denoised = kornia.median_blur(x, (self.ksize, self.ksize))
    elif self.filter_type == 'Mean_Filter':
        x_denoised = kornia.box_blur(x, (self.ksize, self.ksize))
    elif self.filter_type == 'Gaussian_Filter':
        # NOTE(review): sigma is derived from the spatial size of ``x``, not
        # from ``self.ksize``. The expression has the same form as OpenCV's
        # default-sigma rule, which is defined on the kernel size -- confirm
        # that using the feature-map size here is intended.
        sigma = (
            0.3 * ((x.shape[3] - 1) * 0.5 - 1) + 0.8,
            0.3 * ((x.shape[2] - 1) * 0.5 - 1) + 0.8,
        )
        x_denoised = kornia.gaussian_blur2d(
            x, (self.ksize, self.ksize), sigma)
    elif self.filter_type == "NonLocal_Filter":
        x_denoised = self.non_local_op(x, self.embed, self.softmax)
    else:
        # Without this branch an unknown name left ``x_denoised`` unbound and
        # surfaced as an UnboundLocalError on the line below.
        raise ValueError(
            f"Unsupported filter_type: {self.filter_type!r}")

    new_x = x + self.conv_1x1(x_denoised)
    return new_x
def main():
    """Load ``model8.png`` and show it next to box, median and Gaussian blurs.

    The image is read with OpenCV, converted to an RGB tensor, expanded to a
    batch of two, scaled by 1/255 and passed to ``imshow`` before and after
    each blur.

    This is a best-effort demo: any failure is reported on stdout instead of
    propagating to the caller.
    """
    try:
        img_bgr = cv2.imread('model8.png', cv2.IMREAD_COLOR)
        if img_bgr is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # rather than raising; fail here with a meaningful message.
            raise FileNotFoundError("could not read image 'model8.png'")

        x_bgr: torch.Tensor = kornia.image_to_tensor(img_bgr)
        x_rgb: torch.Tensor = kornia.bgr_to_rgb(x_bgr)

        # Create a batch of two and normalize.
        x_rgb = x_rgb.expand(2, -1, -1, -1)
        x_rgb = x_rgb.float() / 255.
        imshow(x_rgb)

        # Box Blur
        x_blur: torch.Tensor = kornia.box_blur(x_rgb, (9, 9))
        imshow(x_blur)

        # Median Blur
        x_blur = kornia.median_blur(x_rgb, (5, 5))
        imshow(x_blur)

        # Gaussian Blur
        x_blur = kornia.gaussian_blur2d(x_rgb, (11, 11), (11., 11.))
        imshow(x_blur)
    except Exception as exc:
        # ``Exception`` (not a bare ``except``) so KeyboardInterrupt and
        # SystemExit still propagate; include the cause in the report.
        print(f"Error found: {exc}")
# Create batch and normalize
x_rgb = x_rgb.expand(2, -1, -1, -1)  # 2xCxHxW
x_rgb = x_rgb.float() / 255.


def imshow(input: torch.Tensor):
    """Tile the images in ``input`` into a grid and draw it with matplotlib.

    The grid is built with ``torchvision.utils.make_grid`` (two images per
    row, one pixel of padding), converted with ``kornia.tensor_to_image`` and
    drawn with the axes hidden.
    """
    out: torch.Tensor = torchvision.utils.make_grid(input, nrow=2, padding=1)
    out_np: np.ndarray = kornia.tensor_to_image(out)
    plt.imshow(out_np)
    plt.axis('off')


#############################
# Show original
imshow(x_rgb)

#############################
# Box Blur
x_blur: torch.Tensor = kornia.box_blur(x_rgb, (9, 9))
imshow(x_blur)

#############################
# Median Blur
x_blur = kornia.median_blur(x_rgb, (5, 5))
imshow(x_blur)

#############################
# Gaussian Blur
x_blur = kornia.gaussian_blur2d(x_rgb, (11, 11), (11., 11.))
imshow(x_blur)