def augmentation_pipeline(image, mask, background_path="unicorn.jpg",
                          out_size=(218, 178)):
    """Randomly rescale and crop an image/mask pair, then swap its background.

    Args:
        image: image tensor, channels first. Only the first 3 channels of the
            cropped result are kept, so this presumably is an RGB (3, H, W)
            tensor — confirm with callers.
        mask: mask tensor with the same spatial size as ``image``; it is
            concatenated with the image along the channel axis.
        background_path: file the replacement background is loaded from.
        out_size: ``(height, width)`` of the returned image.

    Returns:
        The augmented image tensor (channels first, no batch dimension).
    """
    out_h, out_w = out_size

    # Load the background inside a context manager so the file handle is
    # released as soon as the pixels have been copied into a tensor.
    with Image.open(background_path) as background_img:
        alt_background = TF.to_tensor(background_img)

    # The random crop/pad happens before background replacement, so the
    # replaced background itself is never cropped/padded. This prevents zero
    # padding from showing when the image is downscaled and padded back to
    # out_size.
    crop_aug = K.RandomCrop(out_size, pad_if_needed=True)
    rand_scale = torch.rand(1).item() * .4 + .8  # scale between .8x - 1.2x
    new_dimensions = (int(out_h * rand_scale), int(out_w * rand_scale))

    # Resize first. The mask is cast to float *before* resizing because
    # interpolation needs a floating-point dtype (bool/integer masks would
    # fail), and it must match the image for the concatenation below anyway.
    augmented = kornia.resize(image.unsqueeze(0), new_dimensions)
    aug_mask = kornia.resize(mask.unsqueeze(0).type(torch.float32),
                             new_dimensions)

    # Crop back to out_size. K.RandomCrop's generate_parameters() was observed
    # to always give the same box, so image and mask are concatenated and
    # cropped in one call to share the same random parameters.
    concat = crop_aug(torch.cat((augmented, aug_mask), dim=1)).squeeze(0)
    augmented = concat[:3]
    aug_mask = concat[3:]

    augmented = background_replacement(augmented, alt_background, aug_mask)
    augmented = selective_color_distort(augmented, aug_mask)

    return augmented
示例#2
0
def compute_loss(pred_pha_lg, pred_fgr_lg, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha_lg, true_fgr_lg):
    """Return the summed matting loss over the large and small predictions.

    Ground truth is only supplied at the large resolution (``*_lg``). The
    small-resolution targets are obtained by resizing it to the spatial size
    (``shape[2:]``) of the corresponding small prediction.

    Terms, added in this order:
      1-2. L1 on alpha (large, small);
      3-4. L1 between Sobel responses of alpha (large, small);
      5-6. L1 on foreground, restricted to pixels whose true alpha != 0;
      7.   MSE between the upsampled predicted error map and the absolute
           error of the upsampled small alpha against the large target.
    """
    full_hw = true_pha_lg.shape[2:]
    true_pha_sm = kornia.resize(true_pha_lg, pred_pha_sm.shape[2:])
    true_fgr_sm = kornia.resize(true_fgr_lg, pred_fgr_sm.shape[2:])
    fg_lg = true_pha_lg != 0
    fg_sm = true_pha_sm != 0

    # Alpha terms.
    total = F.l1_loss(pred_pha_lg, true_pha_lg)
    total = total + F.l1_loss(pred_pha_sm, true_pha_sm)
    # Alpha-gradient terms.
    total = total + F.l1_loss(kornia.sobel(pred_pha_lg), kornia.sobel(true_pha_lg))
    total = total + F.l1_loss(kornia.sobel(pred_pha_sm), kornia.sobel(true_pha_sm))
    # Foreground terms, only where the ground-truth alpha is non-zero.
    total = total + F.l1_loss(pred_fgr_lg * fg_lg, true_fgr_lg * fg_lg)
    total = total + F.l1_loss(pred_fgr_sm * fg_sm, true_fgr_sm * fg_sm)
    # Error-map term, evaluated at the large resolution.
    err_pred = kornia.resize(pred_err_sm, full_hw)
    err_true = kornia.resize(pred_pha_sm, full_hw).sub(true_pha_lg).abs()
    total = total + F.mse_loss(err_pred, err_true)
    return total
def load_img(path: str,
             height: int = 256,
             width: int = 256,
             bits: int = 8,
             plot: bool = False,
             crop_mode: str = "centre-crop",
             save_gt: bool = False,
             **kwargs) -> torch.Tensor:
    """Load an image file as a float tensor cropped to ``height x width``.

    Args:
        path: image file, read with ``cv2.imread(path, -1)`` (unchanged flag,
            so the stored bit depth is kept).
        height: height of the returned crop.
        width: width of the returned crop.
        bits: bit depth used for normalisation; pixel values are divided by
            ``2**bits - 1``.
        plot: if True, show the result with matplotlib.
        crop_mode: ``"resize-crop"`` first rescales the image uniformly
            (aspect ratio kept) so that it covers the whole target window,
            then centre-crops. Any other value centre-crops the image at its
            native resolution.
        save_gt: if True, also write the crop to ``gt.png`` scaled by 255.
        **kwargs: accepted and ignored.

    Returns:
        Tensor of shape (height, width, C).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    img = cv2.imread(path, -1)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"cv2 could not read image: {path}")
    # Reverse the channel axis (assumes a 3-channel BGR image, as OpenCV
    # loads colour images) and normalise to [0, 1].
    img = img[:, :, ::-1] / (2**bits - 1)
    img = torch.from_numpy(img.copy()).float().permute(2, 0, 1)

    if crop_mode == "resize-crop":
        # Scale so the image covers the full height x width window. Resizing
        # only the short side to min(height, width) is not enough for
        # non-square targets: the centre crop would then sample outside the
        # image. For square targets this is the same as resizing the short side.
        src_h, src_w = img.shape[-2:]
        scale = max(height / src_h, width / src_w)
        # int() truncates (as kornia's side-based sizing does); max() guards
        # against float round-off landing one pixel short of the target.
        new_size = (max(height, int(src_h * scale)),
                    max(width, int(src_w * scale)))
        img = kornia.resize(img.unsqueeze(0),
                            new_size,
                            align_corners=False).squeeze(0)

    img = kornia.center_crop(img.unsqueeze(0), (height, width),
                             align_corners=False)
    # C x H x W -> H x W x C
    img = img.squeeze(0).permute(1, 2, 0)

    if plot:
        plt.imshow(img)
        plt.show()

    if save_gt:
        # Flip back to BGR for OpenCV and rescale to the 8-bit range.
        cv2.imwrite("gt.png", img.numpy()[:, :, ::-1] * 255.0)

    # H x W x C
    return img
def random_crop(*imgs):
    """Resize every tensor to one random square, then centre-crop a random box.

    A width and then a height are each drawn uniformly from [256, 512). Each
    input is resized to a square whose side is the larger of the two, then
    centre-cropped to (height, width). All inputs share the same draw.

    Returns:
        List of the cropped tensors, in input order.
    """
    # randrange(a, b) draws exactly like choice(range(a, b)).
    w = random.randrange(256, 512)
    h = random.randrange(256, 512)
    side = max(h, w)
    return [kornia.center_crop(kornia.resize(img, (side, side)), (h, w))
            for img in imgs]
示例#5
0
    def test_smoke(self, device, dtype):
        # Resizing to the tensor's own spatial size (3, 4) must be a near
        # no-op, for 4D, 2D, 3D and higher-rank inputs alike.
        shapes = (
            (1, 3, 3, 4),
            (3, 4),
            (3, 3, 4),
            (1, 2, 3, 2, 1, 3, 3, 4),
        )
        for shape in shapes:
            inp = torch.rand(*shape, device=device, dtype=dtype)
            out = kornia.resize(inp, (3, 4))
            assert_close(inp, out, atol=1e-4, rtol=1e-4)
示例#6
0
    def test_one_param_long(self, device, dtype):
        # side="long": the longer (vertical) side becomes 10 and the width
        # follows the 5:2 ratio. Leading dims cover 4D, 2D, 3D and N-D input.
        for lead in ((1, 3), (), (3,), (1, 2, 3, 2, 1, 3)):
            inp = torch.rand(*lead, 5, 2, device=device, dtype=dtype)
            out = kornia.resize(inp, 10, side="long")
            assert out.shape == (*lead, 10, 4)
示例#7
0
    def test_one_param_horz(self, device, dtype):
        """side="horz" sets the width to 10 and scales the height by 2:5.

        Each section uses an input of the rank named in its comment. The
        previous version re-tested the 4D tensor four times, so the 2D, 3D
        and arbitrary-rank paths were never exercised.
        """
        inp = torch.rand(1, 3, 2, 5, device=device, dtype=dtype)
        out = kornia.resize(inp, 10, side="horz")
        assert out.shape == (1, 3, 4, 10)

        # 2D
        inp = torch.rand(2, 5, device=device, dtype=dtype)
        out = kornia.resize(inp, 10, side="horz")
        assert out.shape == (4, 10)

        # 3D
        inp = torch.rand(3, 2, 5, device=device, dtype=dtype)
        out = kornia.resize(inp, 10, side="horz")
        assert out.shape == (3, 4, 10)

        # arbitrary dim
        inp = torch.rand(1, 2, 3, 2, 1, 3, 2, 5, device=device, dtype=dtype)
        out = kornia.resize(inp, 10, side="horz")
        assert out.shape == (1, 2, 3, 2, 1, 3, 4, 10)
示例#8
0
    def test_upsize(self, device, dtype):
        # Upscaling the trailing (3, 4) plane to (6, 8) must keep every
        # leading dimension: 4D, 2D, 3D and arbitrary rank.
        for lead in ((1, 3), (), (3,), (1, 2, 3, 2, 1, 3)):
            inp = torch.rand(*lead, 3, 4, device=device, dtype=dtype)
            out = kornia.resize(inp, (6, 8))
            assert out.shape == (*lead, 6, 8)
示例#9
0
    def test_downsize(self, device, dtype):
        # Downscaling the trailing (5, 2) plane to (3, 1) must keep every
        # leading dimension: 4D, 2D, 3D and arbitrary rank.
        for lead in ((1, 3), (), (3,), (1, 2, 3, 2, 1, 3)):
            inp = torch.rand(*lead, 5, 2, device=device, dtype=dtype)
            out = kornia.resize(inp, (3, 1))
            assert out.shape == (*lead, 3, 1)
示例#10
0
    def test_downsizeAA(self, device, dtype):
        # Antialiased downscaling of the trailing (10, 8) plane to (5, 3)
        # must keep every leading dimension: 4D, 2D, 3D and arbitrary rank.
        for lead in ((1, 3), (), (3,), (1, 2, 3, 2, 1, 3)):
            inp = torch.rand(*lead, 10, 8, device=device, dtype=dtype)
            out = kornia.resize(inp, (5, 3), antialias=True)
            assert out.shape == (*lead, 5, 3)
示例#11
0
    def test_one_param_vert(self, device, dtype):
        # side="vert": the height becomes 10 and the width follows the 5:2
        # ratio. Leading dims cover 4D, 2D, 3D and arbitrary-rank input.
        for lead in ((1, 3), (), (3,), (1, 2, 3, 2, 1, 3)):
            inp = torch.rand(*lead, 5, 2, device=device, dtype=dtype)
            out = kornia.resize(inp, 10, align_corners=False, side="vert")
            assert out.shape == (*lead, 10, 4)
示例#12
0
    def test_one_param(self, device, dtype):
        # With a single int and the default side, the short side (width 2)
        # becomes 10 and the height scales to 25. Leading dims cover 4D, 2D,
        # 3D and arbitrary-rank input.
        for lead in ((1, 3), (), (3,), (1, 2, 3, 2, 1, 3)):
            inp = torch.rand(*lead, 5, 2, device=device, dtype=dtype)
            out = kornia.resize(inp, 10)
            assert out.shape == (*lead, 25, 10)
def random_crop(*imgs):
    """Scale every tensor to cover one random size, then centre-crop to it.

    A target width and then a height are drawn from [1024, 2048) and rounded
    down to a multiple of 4. Each input is scaled uniformly (aspect ratio
    kept) just enough to cover the target, then centre-cropped to it. The
    source size is read from the first input, so all inputs are expected to
    share it.

    Args:
        *imgs: tensors whose ``shape[2:]`` is (H, W) — presumably
            (B, C, H, W); confirm with callers.

    Returns:
        List of cropped tensors, in input order.
    """
    H_src, W_src = imgs[0].shape[2:]
    W_tgt = random.choice(range(1024, 2048)) // 4 * 4
    H_tgt = random.choice(range(1024, 2048)) // 4 * 4
    scale = max(W_tgt / W_src, H_tgt / H_src)
    # int() can land one pixel short of the target because of float round-off
    # (e.g. int(49 * (1024 / 49)) == 1023). center_crop would then read
    # outside the resized image, so clamp to the target size.
    resized_hw = (max(H_tgt, int(H_src * scale)),
                  max(W_tgt, int(W_src * scale)))
    return [kornia.center_crop(kornia.resize(img, resized_hw), (H_tgt, W_tgt))
            for img in imgs]
示例#14
0
def resize(
    input: torch.Tensor,
    size: Union[int, Tuple[int, int]],
    align_corners: Optional[bool] = None,
    interpolation: str = "bilinear",
    side: str = "short",
) -> torch.Tensor:
    """Resize ``input`` with ``kornia.resize`` after resolving ``align_corners``.

    ``align_corners`` is first normalised by ``parse_align_corners`` (defined
    elsewhere) for the chosen interpolation mode. Every other argument is
    forwarded to kornia unchanged.
    """
    resolved_corners = parse_align_corners(align_corners, interpolation)
    return kornia.resize(
        input,
        size,
        interpolation=interpolation,
        align_corners=resolved_corners,
        side=side,
    )
示例#15
0
    def augment(self, image, mask):
        """Randomly rescale/crop an image+mask pair, swap the background, flip.

        The pair is rescaled by a random factor in [0.85, 1.4) relative to
        218x178 and cropped with ``self.crop_aug``. The background is then
        replaced with a randomly chosen entry of ``self.images``, followed by
        ``self.selective_color_distort`` and ``self.flip_aug``.

        Args:
            image: image tensor, channels first. Only the first 3 channels of
                the crop are kept, so presumably RGB (3, H, W) — confirm.
            mask: mask tensor with the same spatial size as ``image``.

        Returns:
            The augmented image tensor, without a batch dimension.
        """
        # scale between .85x - 1.4x
        rand_scale = torch.rand(1).item() * .55 + .85
        new_dimensions = (int(218 * rand_scale), int(178 * rand_scale))

        augmented = kornia.resize(image.unsqueeze(0), new_dimensions)
        # Cast before resizing: interpolation needs a floating-point dtype,
        # and the mask must match the float image for the concatenation below.
        aug_mask = kornia.resize(mask.unsqueeze(0).type(torch.float32),
                                 new_dimensions)

        # Crop image and mask in a single call so both get the same random box.
        concat = self.crop_aug(torch.cat((augmented, aug_mask),
                                         dim=1)).squeeze(0)
        augmented = concat[:3]
        aug_mask = concat[3:]

        background_file = self.images[torch.randint(len(self.images),
                                                    (1, )).item()]
        # Context manager releases the file handle once the pixels are loaded.
        with Image.open(background_file) as background_img:
            rand_background = TF.to_tensor(background_img)
        augmented = self.background_replacement(augmented, rand_background,
                                                aug_mask)
        augmented = self.selective_color_distort(augmented, aug_mask)
        augmented = self.flip_aug(augmented.unsqueeze(0)).squeeze(0)

        return augmented
示例#16
0
    def rp(self, x):
        """Randomly resize ``x`` to a square, then pad it to ``self.max_size``.

        The side length is drawn with ``np.random.randint`` from
        [x.shape[-1], self.max_size). The resized tensor is placed at a random
        offset inside a max_size x max_size canvas filled with ``self.value``.
        Offsets are drawn from [0, slack), so at least one row/column of
        padding always ends up on the right and bottom.
        """
        side = np.random.randint(x.shape[-1], self.max_size)
        resized = kornia.resize(x, size=(side, side))

        # The resize is square, so the leftover space is equal on both axes.
        slack = self.max_size - side
        left = np.random.randint(0, slack)
        top = np.random.randint(0, slack)
        padding = [left, slack - left, top, slack - top]

        return torch.nn.functional.pad(resized,
                                       padding,
                                       mode='constant',
                                       value=self.value)
示例#17
0
    def clip_similarity(self, input):
        """Return the CLIP cosine similarity between ``input`` and the target.

        ``self.config.task`` selects the direction:

        * ``"txt2img"``: ``input`` is an image batch. It is resized to
          224x224, optionally passed through ``self.augmentation``, encoded
          with ``self.CLIP.encode_image`` and compared to
          ``self.text_features``.
        * ``"img2txt"``: ``input`` is text for ``clip.tokenize``. It is
          encoded with ``self.CLIP.encode_text`` and compared to
          ``self.image_features``. If tokenisation fails, a zero score per
          element of ``input`` is returned on ``self.config.device``.

        Raises:
            ValueError: if ``self.config.task`` is neither of the above.
        """
        task = self.config.task
        if task == "txt2img":
            image = kornia.resize(input, (224, 224))
            if self.augmentation is not None:
                image = self.augmentation(image)
            image_features = self.CLIP.encode_image(image)
            return torch.cosine_similarity(image_features, self.text_features)

        if task == "img2txt":
            try:
                text_tokens = clip.tokenize(input)
            except Exception:
                # Presumably e.g. text too long for CLIP's context length —
                # score it as 0 instead of aborting. Unlike the former bare
                # `except:`, this does not swallow KeyboardInterrupt/SystemExit.
                return torch.zeros(len(input), device=self.config.device)
            text_features = self.CLIP.encode_text(
                text_tokens.to(self.config.device))
            return torch.cosine_similarity(text_features, self.image_features)

        # Previously an unknown task fell through to an UnboundLocalError.
        raise ValueError(f"Unsupported task: {task!r}")
示例#18
0
 def test_downsizeAA(self, device, dtype):
     # Antialiased downscaling must still land exactly on the requested size.
     src = torch.rand(1, 3, 10, 8, device=device, dtype=dtype)
     expected_shape = (1, 3, 5, 3)
     resized = kornia.resize(src, expected_shape[-2:], antialias=True)
     assert resized.shape == expected_shape
示例#19
0
 def forward(self, x: torch.Tensor) -> torch.Tensor:
     """Resize a 4-D tensor to ``self._size`` with ``K.resize``."""
     assert isinstance(x, torch.Tensor)
     # The assertion message reports the offending shape.
     assert x.dim() == 4, x.shape
     target_size = self._size
     return K.resize(x, target_size)
示例#20
0
import kornia.geometry as KG


def load_timg(file_name):
    """Load an image with OpenCV and return it as an RGB float tensor.

    The image is read in colour (``cv2.IMREAD_COLOR``), converted with
    ``K.image_to_tensor(img, None)``, scaled by 1/255 and converted from BGR
    to RGB.

    Raises:
        FileNotFoundError: if ``file_name`` is not an existing file.
        ValueError: if OpenCV cannot decode the file.
    """
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not os.path.isfile(file_name):
        raise FileNotFoundError(f"Invalid file {file_name}")
    img = cv2.imread(file_name, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread returns None (rather than raising) for unreadable files.
        raise ValueError(f"OpenCV could not decode {file_name}")
    # keepdim=None is falsy, which presumably makes image_to_tensor add a
    # leading batch dimension — confirm against the kornia version in use.
    tensor = K.image_to_tensor(img, None).float() / 255.
    return K.color.bgr_to_rgb(tensor)


# Register the CT image against the MR image with an ImageRegistrator built
# for the 'similarity' model. Both images are resized to 400x600 first.
registrator = KG.ImageRegistrator('similarity')

img1 = K.resize(load_timg('/Users/oldufo/datasets/stewart/MR-CT/CT.png'),
                (400, 600))
img2 = K.resize(load_timg('/Users/oldufo/datasets/stewart/MR-CT/MR.png'),
                (400, 600))
# output_intermediate_models=True makes register() also return the
# intermediate models; they are replayed below, presumably to animate the
# optimisation as a GIF.
model, intermediate = registrator.register(img1,
                                           img2,
                                           output_intermediate_models=True)

video_writer = imageio.get_writer('medical_registration.gif', fps=2)

# First frame: img1 with its first channel overwritten by img2's first
# channel, presumably so both images show up as a colour overlay before
# any alignment.
timg_dst_first = img1.clone()
timg_dst_first[0, 0, :, :] = img2[0, 0, :, :]
video_writer.append_data(K.tensor_to_image((timg_dst_first * 255.).byte()))

# Warp img1 with each intermediate model to img2's spatial size.
# NOTE(review): the loop body looks truncated here — the warped frames are
# not appended to video_writer and the writer is never closed in this
# fragment; confirm against the original script.
with torch.no_grad():
    for m in intermediate:
        timg_dst = KG.homography_warp(img1, m, img2.shape[-2:])
    def infer(self,
              img_path,
              cont_path=None,
              mode=None,
              color=None,
              text=None,
              mask_path=None,
              gt_path=None,
              output_dir=None):
        """Corrupt one image with a contaminant, restore it, save or show it.

        The input image is blended with contaminant content ``c_im_t``
        inside ``masks``. The blend goes through ``self.mpn`` (predicted masks
        and ``neck`` features) and then ``self.rin`` to produce the output.

        Args:
            img_path: input image; resized to ``opt.DATASET.SIZE`` square.
            cont_path: optional contaminant image. Modes 1, 5, 6, 7, 8 require
                it; modes 2, 3, 4 require it to be ``None``.
            mode: corruption mode, defaulting to ``opt.TEST.MODE``:
                1/5 contaminant image used as-is; 2 uniform noise;
                3 solid brush colour; 4 the input image itself;
                6 a 1/8-size thumbnail at a random spot (the mask becomes that
                spot); 7 ``text`` drawn centred in ``color`` (the mask becomes
                the text); 8 centre crop of the contaminant, with a generated
                mask shrunk to the centre and mirrored.
            color: colour name looked up in ``COLORS`` (modes 3 and 7). Mode 3
                falls back to ``opt.TEST.BRUSH_COLOR``.
            text: text for mode 7, defaulting to ``opt.TEST.TEXT``.
            mask_path: optional mask image. Without it a mask is generated
                with ``self.mask_generator``.
            gt_path: optional ground-truth image; only shown, never saved.
            output_dir: if given, a side-by-side strip is saved there as
                ``out{mode}_{color}_{basename}``. Otherwise the intermediate
                images are opened with ``.show()``.
        """
        mode = self.opt.TEST.MODE if mode is None else mode
        text = self.opt.TEST.TEXT if text is None else text
        size = self.opt.DATASET.SIZE

        with torch.no_grad():
            # `with` closes each source file; convert() returns an
            # independent in-memory copy.
            with Image.open(img_path) as src:
                im = src.convert("RGB")
            im = im.resize((size, size))
            im_t = linear_scaling(
                transforms.ToTensor()(im).unsqueeze(0).cuda())

            if gt_path is not None:
                with Image.open(gt_path) as src:
                    gt = src.convert("RGB")
                gt = gt.resize((size, size))

            if mask_path is None:
                masks = torch.from_numpy(
                    self.mask_generator.generate(size, size)).float().cuda()
            else:
                with Image.open(mask_path) as src:
                    masks = src.convert("L")
                masks = masks.resize((size, size))
                masks = self.tensorize(masks).unsqueeze(0).float().cuda()

            if cont_path is not None:
                assert mode in [1, 5, 6, 7, 8]
                with Image.open(cont_path) as src:
                    c_im = src.convert("RGB")
                c_im = c_im.resize((size, size))
                if mode == 6:
                    # Shrink the contaminant to 1/8 size and paste it at one
                    # random location; the mask becomes exactly that patch.
                    c_im = c_im.resize((size // 8, size // 8))
                    c_im_t = self.tensorize(c_im).unsqueeze(0).cuda()
                    r_c_im_t = torch.zeros((1, 3, size, size)).cuda()
                    masks = torch.zeros((1, 1, size, size)).cuda()
                    coord_x, coord_y = np.random.randint(size - size // 8,
                                                         size=(2, ))
                    patch_h, patch_w = c_im_t.size(2), c_im_t.size(3)
                    r_c_im_t[:, :, coord_x:coord_x + patch_h,
                             coord_y:coord_y + patch_w] = c_im_t
                    masks[:, :, coord_x:coord_x + patch_h,
                          coord_y:coord_y + patch_w] = torch.ones_like(
                              c_im_t[0, 0])
                    c_im_t = linear_scaling(r_c_im_t)
                elif mode == 7:
                    # Draw `text` centred on the contaminant; the mask is the
                    # same text rendered with fill 255 on a zero image.
                    mask = self.to_pil(torch.zeros((size, size)))
                    d = ImageDraw.Draw(c_im)
                    d_m = ImageDraw.Draw(mask)
                    font = ImageFont.truetype(self.opt.TEST.FONT,
                                              self.opt.TEST.FONT_SIZE)
                    # NOTE(review): ImageDraw.textsize was removed in
                    # Pillow 10; switch to textbbox when upgrading.
                    font_w, font_h = d.textsize(text, font=font)
                    c_w = (size - font_w) // 2
                    c_h = (size - font_h) // 2
                    d.text((c_w, c_h),
                           text,
                           font=font,
                           fill=tuple([
                               int(a * 255)
                               for a in COLORS["{}".format(color).upper()]
                           ]))
                    d_m.text((c_w, c_h), text, font=font, fill=255)
                    masks = self.tensorize(mask).unsqueeze(0).float().cuda()
                    c_im_t = linear_scaling(self.tensorize(c_im).cuda())
                elif mode == 8:
                    # Paste a centre crop (half the image size) of the
                    # contaminant into the middle of an empty canvas. The
                    # half size and offset are derived from `size`; they were
                    # previously hard-coded as 128, which only matched
                    # DATASET.SIZE == 256.
                    half = size // 2
                    coord_x = coord_y = (size - half) // 2
                    center_cropper = transforms.CenterCrop((half, half))
                    crop = self.tensorize(center_cropper(c_im))
                    r_c_im_t = torch.zeros((1, 3, size, size)).cuda()
                    r_c_im_t[:, :, coord_x:coord_x + half,
                             coord_y:coord_y + half] = crop
                    if mask_path is None:
                        # Shrink the generated mask into the same central
                        # window, add its horizontal mirror and clamp to [0, 1].
                        tmp = kornia.resize(masks, half)
                        masks = torch.zeros((1, 1, size, size)).cuda()
                        masks[:, :, coord_x:coord_x + half,
                              coord_y:coord_y + half] = tmp
                        tmp = kornia.hflip(tmp)
                        masks[:, :, coord_x:coord_x + half,
                              coord_y:coord_y + half] += tmp
                        masks = torch.clamp(masks, min=0., max=1.)
                    c_im_t = linear_scaling(r_c_im_t)
                else:
                    c_im_t = linear_scaling(
                        transforms.ToTensor()(c_im).unsqueeze(0).cuda())
            else:
                assert mode in [2, 3, 4]
                if mode == 2:
                    c_im_t = linear_scaling(torch.rand_like(im_t))
                elif mode == 3:
                    color = self.opt.TEST.BRUSH_COLOR if color is None else color
                    brush = torch.tensor(
                        list(COLORS["{}".format(color).upper()])).unsqueeze(
                            0).unsqueeze(-1).unsqueeze(-1).cuda()
                    c_im_t = linear_scaling(torch.ones_like(im_t) * brush)
                elif mode == 4:
                    c_im_t = im_t

            # Soften generated masks (and any mask in mode 5); mode 8 keeps
            # its hard mask.
            if (mask_path is None or mode == 5) and mode != 8:
                smooth_masks = self.mask_smoother(1 - masks) + masks
                smooth_masks = torch.clamp(smooth_masks, min=0., max=1.)
            else:
                smooth_masks = masks

            masked_imgs = c_im_t * smooth_masks + im_t * (1. - smooth_masks)
            pred_masks, neck = self.mpn(masked_imgs)
            pred_masks = pred_masks if mode != 8 else torch.clamp(
                pred_masks * smooth_masks, min=0., max=1.)
            masked_imgs_embraced = masked_imgs * (1. - pred_masks)
            output = self.rin(masked_imgs_embraced, pred_masks, neck)
            output = torch.clamp(output, max=1., min=0.)

            if output_dir is not None:
                out_path = os.path.join(
                    output_dir, "out{}_{}_{}".format(mode, color,
                                                     img_path.split("/")[-1]))
                if mode == 8:
                    # input | contaminant | corrupted | output, side by side
                    strip = torch.cat([
                        linear_unscaling(im_t).squeeze().cpu(),
                        self.tensorize(c_im).squeeze().cpu(),
                        linear_unscaling(masked_imgs).squeeze().cpu(),
                        output.squeeze().cpu()
                    ],
                                      dim=2)
                else:
                    # corrupted over output (concatenated along dim 1)
                    strip = torch.cat([
                        linear_unscaling(masked_imgs).squeeze().cpu(),
                        output.squeeze().cpu()
                    ],
                                      dim=1)
                self.to_pil(strip).save(out_path)
            else:
                self.to_pil(output.squeeze().cpu()).show()
                self.to_pil(pred_masks.squeeze().cpu()).show()
                self.to_pil(
                    linear_unscaling(masked_imgs).squeeze().cpu()).show()
                self.to_pil(smooth_masks.squeeze().cpu()).show()
                self.to_pil(linear_unscaling(im_t).squeeze().cpu()).show()
                if gt_path is not None:
                    gt.show()
示例#22
0
 def test_one_param_long(self, device, dtype):
     # side="long": the longer (vertical) side becomes 10 and the width
     # follows the 5:2 ratio.
     src = torch.rand(1, 3, 5, 2, device=device, dtype=dtype)
     resized = kornia.resize(src, 10, side="long")
     assert tuple(resized.shape) == (1, 3, 10, 4)
示例#23
0
 def test_smoke(self, device, dtype):
     # Resizing to the input's own spatial size should be a near no-op.
     src = torch.rand(1, 3, 3, 4, device=device, dtype=dtype)
     same_size = tuple(src.shape[-2:])
     assert_allclose(src, kornia.resize(src, same_size), atol=1e-4, rtol=1e-4)
示例#24
0
 def test_one_param_vert(self, device):
     # side="vert": the height becomes 10 and the width follows the 5:2 ratio.
     src = torch.rand(1, 3, 5, 2).to(device)
     expected = (1, 3, 10, 4)
     assert kornia.resize(src, 10, side="vert").shape == expected
示例#25
0
 def test_one_param_horz(self, device):
     # side="horz": the width becomes 10 and the height follows the 2:5 ratio.
     src = torch.rand(1, 3, 2, 5).to(device)
     expected = (1, 3, 4, 10)
     assert kornia.resize(src, 10, side="horz").shape == expected
示例#26
0
 def test_one_param(self, device):
     # A single int resizes the short side (width 2) to 10; the height scales
     # with it to 25.
     src = torch.rand(1, 3, 5, 2).to(device)
     expected = (1, 3, 25, 10)
     assert kornia.resize(src, 10).shape == expected
示例#27
0
 def test_downsize(self, device):
     # An explicit (H, W) smaller than the input is honoured exactly.
     src = torch.rand(1, 3, 5, 2).to(device)
     target_hw = (3, 1)
     assert kornia.resize(src, target_hw).shape == (1, 3) + target_hw
示例#28
0
 def test_upsize(self, device):
     # An explicit (H, W) larger than the input is honoured exactly.
     src = torch.rand(1, 3, 3, 4).to(device)
     target_hw = (6, 8)
     assert kornia.resize(src, target_hw).shape == (1, 3) + target_hw
示例#29
0
 def test_smoke(self, device):
     # Resizing to the input's own spatial size should be a near no-op.
     src = torch.rand(1, 3, 3, 4).to(device)
     resized = kornia.resize(src, tuple(src.shape[2:]))
     assert_allclose(src, resized)
# Load the example image. It is upscaled before inference because DeepLab
# relies heavily on dilated convolutions and does not work well on
# low-resolution images. PIL's resize takes (width, height), so the result
# is 512 pixels tall and 418 wide.
image = Image.open("nate.jpg")
scaled_image = image.resize((418, 512), resample=Image.LANCZOS)
image_tensor = TF.to_tensor(scaled_image)

# Run the network without tracking gradients. unsqueeze(0) adds a batch
# dimension so the single image is fed as a batch of one.
# NOTE(review): `network` and `device` are defined outside this view; this
# assumes network(...) returns a tensor rather than a dict of outputs — confirm.
# The predicted mask is then resized with kornia to 218x178 (H x W), and
# squeeze(0) drops the batch dimension again. The result stays on `device`.
with torch.no_grad():
    mask_large = network(image_tensor.unsqueeze(0).to(device))
mask = kornia.resize(mask_large, (218, 178)).squeeze(0)


# Save each channel of a mask tensor as its own grayscale JPEG
# (<folder>/<image_name>_<channel index>.jpg).
def save_mask(mask, image_name, folder):
    """Write every channel of ``mask`` to disk as an 8-bit grayscale JPEG.

    Args:
        mask: tensor of raw scores, iterated over its first (channel)
            dimension; a sigmoid is applied and the result scaled to [0, 255].
        image_name: file-name prefix.
        folder: directory to write into (it is not created here).
    """
    # detach() + cpu() so masks that live on the GPU, or still track
    # gradients, can be converted to NumPy (np.array() fails on both).
    probs = torch.sigmoid(mask).detach().cpu().numpy()
    mask = np.uint8(probs * 255)
    for i, channel in enumerate(mask):
        channel_im = Image.fromarray(channel, mode='L')
        channel_im.save(f"{folder}/{image_name}_{i}.jpg")


save_mask(mask, image_name="nate", folder="output")


# finally, I have a function for plotting the mask as a single