Example #1
0
    def visualize(self, layer, filt, lr=1e-2, opt_steps=36):
        """Synthesize an image that maximizes one feature map's mean activation.

        Starting from a random image of side ``self.size`` (uniform values in
        [150, 180)), the pixels are optimized with Adam for ``opt_steps``
        steps. Between scales the image is enlarged by
        ``self.upscaling_factor`` and box-blurred; this is repeated for
        ``self.upscaling_steps`` scales. The image produced by the last scale
        is stored in ``self.output`` without a further resize/blur, and
        ``self.save(layer, filt)`` is called once at the end.

        Args:
            layer: Index into ``list(self.target.children())`` selecting the
                module handed to ``SaveFeatures``.
            filt: Index on the second axis of the captured features; the mean
                of that slice is what gets maximized.
            lr: Learning rate for the Adam optimizer.
            opt_steps: Number of optimizer steps per scale.
        """
        sz = self.size
        img = Image.fromarray(
            np.uint8(np.random.uniform(150, 180, (sz, sz, 3))))
        # NOTE(review): SaveFeatures presumably registers a forward hook on
        # the module; close() is called in `finally` so it is released even
        # when optimization or saving fails.
        activations = SaveFeatures(list(self.target.children())[layer])
        try:
            self.model.zero_grad()
            last_scale = self.upscaling_steps - 1
            for outer in tqdm(range(self.upscaling_steps), leave=False):
                # Move to the GPU first, then enable gradients, so that the
                # tensor given to the optimizer is a leaf tensor.
                img_var = (ToTensor()(img).unsqueeze(0)
                           .cuda(ORDINAL).requires_grad_(True))
                optimizer = torch.optim.Adam(
                    [img_var], lr=lr, weight_decay=1e-6)

                pbar = tqdm(range(opt_steps), leave=False)
                for _ in pbar:
                    optimizer.zero_grad()
                    # The return value is not needed: the forward pass is run
                    # only to refresh `activations.features`.
                    self.model(img_var)
                    # Negated so that minimizing the loss maximizes the
                    # activation. The norm term carries a weight of zero.
                    loss = -activations.features[
                        0, filt].mean() + 0.00 * torch.norm(img_var)
                    loss.backward()
                    pbar.set_description(f'Loss: {loss.item()}')
                    optimizer.step()

                sz = int(sz * self.upscaling_factor)
                img = ToPILImage()(img_var.detach().cpu().squeeze(0))
                # Only enlarge and blur when another optimization pass will
                # follow; the final result is kept at its optimized size.
                if outer != last_scale:
                    img = img.resize((sz, sz))
                    img = img.filter(ImageFilter.BoxBlur(2))
                self.output = img.copy()
            self.save(layer, filt)
        finally:
            activations.close()
Example #2
0
    def __getitem__(self, index):
        """Return one ``(network_input, target)`` pair for sample ``index``.

        The image at ``self.image_filenames[index]`` is loaded and, when
        ``self.crop`` is set, augmented with a random 64-pixel crop, random
        flips and a random rotation. Channels 0-2 form the guide image and
        channel 3 the depth map.

        Returns:
            A tuple ``(network_input, target)`` where ``network_input`` is the
            degraded depth map concatenated along dim 0 with the grayscale
            guide, and ``target`` is the untouched depth map as a tensor.
        """
        # The file holds a pre-composited 4-channel RGBD image.
        sample = load_img(self.image_filenames[index])

        # Data augmentation
        if self.crop:
            augmentations = (
                RandomCrop(64),  # take a patch
                RandomHorizontalFlip(),  # horizontal flip
                RandomVerticalFlip(),  # vertical flip
                RandomRotation(180),  # random rotation
            )
            for augment in augmentations:
                sample = augment(sample)

        channels = ToTensor()(sample)
        to_pil = ToPILImage()
        # Channel 3 is the depth map, channels 0-2 the colour guide.
        depth_pil = to_pil(channels[3].unsqueeze(0))
        rgb_pil = to_pil(channels[:3])

        shorter_side = min(depth_pil.size)
        # Keep an untouched copy of the depth map as the target.
        target_pil = depth_pil.copy()
        guide = ToTensor()(rgb_pil.convert('L'))

        # Generate the low-resolution depth map, then resize it with bicubic
        # interpolation using the original shorter side as the target size.
        degraded = downsampling(depth_pil, self.upscale_factor)
        restore = Resize(size=shorter_side, interpolation=Image.BICUBIC)
        degraded = ToTensor()(restore(degraded))

        # Concatenate depth and guide to produce the input tensor.
        network_input = torch.cat((degraded, guide), 0)
        return network_input, ToTensor()(target_pil)
def create_image(net, loader, name, upscaling=2, multiscale=False):
    """Save a low-res / reconstructed / high-res comparison figure.

    The first sample of ``loader`` is run through ``net``. The top row of the
    figure shows the low-resolution input (with the three patch outlines
    drawn on it), the reconstruction and the high-resolution reference; the
    titles carry the dataset-wide PSNR/SSIM averages returned by
    ``compute_metrics``. The three rows below show the same fixed patches
    cropped from each image at the upscaled resolution.

    Args:
        net: Model to evaluate. It is switched to eval mode and not switched
            back.
        loader: Indexable dataset whose items unpack as
            ``(highres, lowres)``; it is also wrapped in a ``DataLoader`` for
            ``compute_metrics``.
        name: Destination passed to ``Figure.savefig``.
        upscaling: Integer scale factor between the low- and high-resolution
            images; forwarded to ``net`` when ``multiscale`` is true.
        multiscale: When true, ``net`` is called as ``net(lowres, upscaling)``
            instead of ``net(lowres)``.
    """
    patch_size = 30
    # Top-left corners (x, y) of the patches, in low-resolution pixels.
    # NOTE(review): these are fixed coordinates — presumably chosen for one
    # particular image size; confirm they fit the dataset in use.
    patch_corners = [(85, 90), (160, 140), (350, 80)]
    patches = np.array(
        [[x, y, x + patch_size, y + patch_size] for x, y in patch_corners])

    highres, lowres = loader[0]
    # NOTE(review): the device is hard-coded; `net` is assumed to already
    # live on cuda:0.
    lowres = lowres.to(torch.device("cuda:0")).unsqueeze(0)
    net.eval()
    with torch.no_grad():
        if multiscale:
            superres = net(lowres, upscaling)
        else:
            superres = net(lowres)

    def as_image(tensor):
        # Drops the leading dimension before conversion. `highres` was never
        # batched, so for it this only succeeds when dim 0 has size 1.
        return ToImage()(tensor.cpu().view(tensor.size()[1:]))

    lowres = as_image(lowres)
    superres = as_image(superres)
    highres = as_image(highres)

    # Outline the patches on a copy so the crops below stay unmarked.
    lowres_draw = lowres.copy()
    draw = ImageDraw.Draw(lowres_draw)
    for box in patches:
        draw.rectangle(box.tolist(), outline='white')

    # Bring the low-res image to the size of the other two for the crops.
    lowres = lowres.resize((lowres.size[0] * upscaling,
                            lowres.size[1] * upscaling),
                           Image.BICUBIC)

    psnr_low, psnr_super, ssim_low, ssim_super = compute_metrics(
        net, DataLoader(loader), upscaling, multiscale)
    lowres_title = (
        "Low-resolution image\n"
        f"Average PSNR over the dataset: {psnr_low:.2f}\n"
        f"Average SSIM over the dataset: {ssim_low:.4f}")
    superres_title = (
        "Reconstructed image\n"
        f"Average PSNR over the dataset: {psnr_super:.2f}\n"
        f"Average SSIM over the dataset: {ssim_super:.4f}")

    fig = plt.figure(figsize=(15, 18), dpi=100)
    try:
        gs = fig.add_gridspec(len(patches) + 1, 3, wspace=0.01, hspace=0.3,
                              left=0.05, top=0.95, bottom=0.02, right=0.95)

        overview = (
            (lowres_draw, {"ylabel": "Image", "title": lowres_title}),
            (superres, {"title": superres_title}),
            (highres, {"title": "High-resolution image"}),
        )
        for col, (image, labels) in enumerate(overview):
            ax = fig.add_subplot(gs[0, col], xticks=[], yticks=[], **labels)
            ax.imshow(np.array(image))

        for row, box in enumerate(patches, start=1):
            # Patch boxes are in low-res pixels; scale them to the upscaled
            # images and hand PIL plain Python ints.
            crop_box = tuple((box * upscaling).tolist())
            for col, image in enumerate((lowres, superres, highres)):
                labels = {"ylabel": f"Patch {row}"} if col == 0 else {}
                ax = fig.add_subplot(gs[row, col], xticks=[], yticks=[],
                                     **labels)
                ax.imshow(np.array(image.crop(crop_box)))

        fig.savefig(name)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)